Add versioning based on input/output types in toco.

This keeps toco as the single source of versioning information. Dequantize is currently the only op whose version is incremented based on the int8 type, but the hybrid ops will follow soon (a sketch of the new pattern appears below the change summary).

PiperOrigin-RevId: 225958674
Suharsh Sivakumar 2018-12-18 01:35:35 -08:00 committed by TensorFlower Gardener
parent 5a02334ec9
commit e4bdb31636
8 changed files with 279 additions and 106 deletions
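The core of this change is that GetVersion now takes an OperatorSignature (the operator plus the model it lives in) instead of a bare Operator, so an op can inspect the data types of its input and output arrays when choosing a version. A minimal sketch of the pattern, mirroring the Dequantize hunk in operator.cc below (the override shown here is illustrative, not an additional file in this diff):

int GetVersion(const OperatorSignature& op_signature) const override {
  // Look up the first input array through the model to see its data type.
  const string& input_name = op_signature.op->inputs[0];
  const Array& input_array = op_signature.model->GetArray(input_name);
  // Signed int8 input requires version 2 of the kernel; everything else
  // stays at version 1.
  if (input_array.data_type == ArrayDataType::kInt8) {
    return 2;
  }
  return 1;
}

Ops whose version does not depend on array types simply return 1 from the same signature, which is why most hunks below are mechanical signature updates.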

View File

@ -376,7 +376,7 @@ struct Operator {
// Output activation arrays. Same comments as for inputs apply here too.
std::vector<string> outputs;
// If true, the array has more outputs than are listed in the 'outputs'
// If true, the operator has more outputs than are listed in the 'outputs'
// member. These need to be resolved by some graph transformation.
// This flag is only here to indicate that an operator should not be
// discarded as unused, even if from its 'outputs' member alone it
@ -2208,6 +2208,16 @@ class Model {
// addresses. See Operator::inputs, Operator::outputs.
std::unordered_map<string, std::unique_ptr<Array>> arrays;
};
// OperatorSignature contains the information required to make versioning
// decisions.
struct OperatorSignature {
// The operator.
const Operator* op;
// The model in which the operator resides.
const Model* model;
};
} // namespace toco
#endif // TENSORFLOW_LITE_TOCO_MODEL_H_

View File

@ -106,16 +106,17 @@ void WriteModelToString(const flatbuffers::FlatBufferBuilder& builder,
namespace details {
OperatorKey::OperatorKey(
const ::toco::Operator& op,
const ::toco::OperatorSignature& op_signature,
const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type,
bool enable_select_tf_ops) {
// Get the op name (by Toco definition).
const ::toco::Operator& op = *op_signature.op;
string name = HelpfulOperatorTypeName(op);
bool is_builtin = false;
const auto& builtin_ops = GetBuiltinOpsMap();
if (ops_by_type.count(op.type) != 0) {
version_ = ops_by_type.at(op.type)->GetVersion(op);
version_ = ops_by_type.at(op.type)->GetVersion(op_signature);
name = ops_by_type.at(op.type)->name();
is_builtin = (builtin_ops.count(name) > 0);
}
@ -190,7 +191,8 @@ void LoadOperatorsMap(
// First find a list of unique operator types.
std::set<OperatorKey> keys;
for (const auto& op : model.operators) {
keys.insert(OperatorKey(*op, ops_by_type, enable_select_tf_ops));
const toco::OperatorSignature op_signature = {op.get(), &model};
keys.insert(OperatorKey(op_signature, ops_by_type, enable_select_tf_ops));
}
// Now assign indices to them and fill in the map.
int index = 0;
@ -301,8 +303,9 @@ Offset<Vector<Offset<OperatorCode>>> ExportOperatorCodes(
std::map<int, Offset<OperatorCode>> ordered_opcodes;
for (const auto& op : model.operators) {
const details::OperatorKey operator_key =
details::OperatorKey(*op, ops_by_type, params.enable_select_tf_ops);
const toco::OperatorSignature op_signature = {op.get(), &model};
const details::OperatorKey operator_key = details::OperatorKey(
op_signature, ops_by_type, params.enable_select_tf_ops);
int op_index = operators_map.at(operator_key);
flatbuffers::Offset<flatbuffers::String> custom_code = 0;
@ -349,9 +352,9 @@ Offset<Vector<Offset<Operator>>> ExportOperators(
for (const string& output : op->outputs) {
outputs.push_back(tensors_map.at(output));
}
const auto key =
details::OperatorKey(*op, ops_by_type, params.enable_select_tf_ops);
const toco::OperatorSignature op_signature = {op.get(), &model};
const auto key = details::OperatorKey(op_signature, ops_by_type,
params.enable_select_tf_ops);
int op_index = operators_map.at(key);
auto tflite_op_it = ops_by_type.find(op->type);

View File

@ -76,7 +76,7 @@ inline void Export(const Model& model, string* output_file_contents) {
namespace details {
// A maps from tensor name to its final position in the TF Lite buffer.
// A map from tensor name to its final position in the TF Lite buffer.
using TensorsMap = std::unordered_map<string, int>;
// A key to identify an operator.
@ -88,7 +88,7 @@ class OperatorKey {
// Construct OperatorKey by Toco op.
OperatorKey(
const ::toco::Operator& op,
const ::toco::OperatorSignature& op_signature,
const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type,
bool enable_select_tf_ops);
@ -158,7 +158,7 @@ class OperatorKey {
std::string flex_tensorflow_op_;
};
// A maps from operator type to its final position in the TF Lite buffer.
// A map from OperatorKey to its final position in the TF Lite buffer.
using OperatorsMap = std::unordered_map<OperatorKey, int, OperatorKey::Hash>;
void LoadTensorsMap(const Model& model, TensorsMap* tensors_map);

View File

@ -301,8 +301,9 @@ class FakeConvolutionOperator
OperatorType::kConv) {}
// Returning the op version according to the op parameters.
int GetVersion(const Operator& op) const override {
const TocoOperator& conv_op = static_cast<const TocoOperator&>(op);
int GetVersion(const OperatorSignature& op_signature) const override {
const TocoOperator& conv_op =
static_cast<const TocoOperator&>(*op_signature.op);
if (conv_op.dilation_width_factor != 1 ||
conv_op.dilation_height_factor != 1) {
// Version 2 if dilation is used.
@ -448,22 +449,58 @@ TEST_F(VersionedOpExportTest, Export) {
}
TEST(OperatorKeyTest, TestBuiltinOp) {
Model model;
auto op = absl::make_unique<ConvOperator>();
// Test a normal float operation.
op->inputs = {"input", "filter"};
op->outputs = {"output"};
Array& input_array = model.GetOrCreateArray(op->inputs[0]);
Array& filter_array = model.GetOrCreateArray(op->inputs[1]);
Array& output_array = model.GetOrCreateArray(op->outputs[0]);
input_array.data_type = ArrayDataType::kFloat;
filter_array.data_type = ArrayDataType::kFloat;
output_array.data_type = ArrayDataType::kFloat;
const auto ops_by_type = BuildOperatorByTypeMap();
const auto key = details::OperatorKey(*op, ops_by_type, false);
const toco::OperatorSignature op_signature = {op.get(), &model};
const auto key = details::OperatorKey(op_signature, ops_by_type, false);
EXPECT_EQ(key.type(), ::tflite::BuiltinOperator_CONV_2D);
EXPECT_EQ(key.custom_code(), "");
EXPECT_EQ(key.version(), 1);
}
TEST(OperatorKeyTest, TestBuiltinOpWithVersionedInputTypes) {
Model model;
auto op = absl::make_unique<DequantizeOperator>();
op->inputs = {"input"};
op->outputs = {"output"};
Array& input_array = model.GetOrCreateArray(op->inputs[0]);
Array& output_array = model.GetOrCreateArray(op->outputs[0]);
input_array.data_type = ArrayDataType::kInt8;
output_array.data_type = ArrayDataType::kFloat;
const auto ops_by_type = BuildOperatorByTypeMap();
// Test a signed int8 dequantize operation.
const toco::OperatorSignature op_signature = {op.get(), &model};
const auto key = details::OperatorKey(op_signature, ops_by_type, false);
EXPECT_EQ(key.type(), ::tflite::BuiltinOperator_DEQUANTIZE);
EXPECT_EQ(key.custom_code(), "");
EXPECT_EQ(key.version(), 2);
}
TEST(OperatorKeyTest, TestCustomOp) {
Model model;
auto op = absl::make_unique<TensorFlowUnsupportedOperator>();
op->tensorflow_op = "MyCrazyCustomOp";
const auto ops_by_type = BuildOperatorByTypeMap();
const auto key = details::OperatorKey(*op, ops_by_type, false);
const toco::OperatorSignature op_signature = {op.get(), &model};
const auto key = details::OperatorKey(op_signature, ops_by_type, false);
EXPECT_EQ(key.type(), ::tflite::BuiltinOperator_CUSTOM);
EXPECT_EQ(key.custom_code(), "MyCrazyCustomOp");
@ -471,12 +508,14 @@ TEST(OperatorKeyTest, TestCustomOp) {
}
TEST(OperatorKeyTest, TestFlexOp) {
Model model;
auto op = absl::make_unique<TensorFlowUnsupportedOperator>();
op->tensorflow_op = "BatchMatMul";
const auto ops_by_type = BuildOperatorByTypeMap();
{
const auto key = details::OperatorKey(*op, ops_by_type, false);
const toco::OperatorSignature op_signature = {op.get(), &model};
const auto key = details::OperatorKey(op_signature, ops_by_type, false);
// It shouldn't be converted to Flex op if `allow_flex_op` is false.
EXPECT_EQ(key.type(), ::tflite::BuiltinOperator_CUSTOM);
EXPECT_EQ(key.custom_code(), "BatchMatMul");
@ -488,7 +527,8 @@ TEST(OperatorKeyTest, TestFlexOp) {
{
// Verify that the custom op name is prefixed by "Flex" and `is_flex_op`
// is true.
const auto key = details::OperatorKey(*op, ops_by_type, true);
const toco::OperatorSignature op_signature = {op.get(), &model};
const auto key = details::OperatorKey(op_signature, ops_by_type, true);
EXPECT_EQ(key.type(), ::tflite::BuiltinOperator_CUSTOM);
EXPECT_EQ(key.custom_code(), "FlexBatchMatMul");
EXPECT_EQ(key.version(), 1);
@ -498,11 +538,13 @@ TEST(OperatorKeyTest, TestFlexOp) {
}
TEST(OperatorKeyTest, TestFlexWithControlFlowOp) {
Model model;
auto op = absl::make_unique<TensorFlowUnsupportedOperator>();
op->tensorflow_op = "Merge";
const auto ops_by_type = BuildOperatorByTypeMap();
const auto key = details::OperatorKey(*op, ops_by_type, true);
const toco::OperatorSignature op_signature = {op.get(), &model};
const auto key = details::OperatorKey(op_signature, ops_by_type, true);
EXPECT_EQ(key.type(), ::tflite::BuiltinOperator_CUSTOM);
EXPECT_EQ(key.custom_code(), "FlexMerge");
@ -514,11 +556,13 @@ TEST(OperatorKeyTest, TestFlexWithControlFlowOp) {
}
TEST(OperatorKeyTest, TestFlexWithUnsupportedOp) {
Model model;
auto op = absl::make_unique<TensorFlowUnsupportedOperator>();
op->tensorflow_op = "HashTableV2";
const auto ops_by_type = BuildOperatorByTypeMap();
const auto key = details::OperatorKey(*op, ops_by_type, true);
const toco::OperatorSignature op_signature = {op.get(), &model};
const auto key = details::OperatorKey(op_signature, ops_by_type, true);
EXPECT_EQ(key.type(), ::tflite::BuiltinOperator_CUSTOM);
EXPECT_EQ(key.custom_code(), "HashTableV2");
@ -532,6 +576,7 @@ TEST(OperatorKeyTest, TestFlexWithUnsupportedOp) {
TEST(OperatorKeyTest, TestFlexWithPartiallySupportedOps) {
// Test Toco-supported/TFLite-unsupported operators.
Model model;
// TODO(ycling): The test will be broken if TensorFlowAssert is implemented in
// TFLite. Find a more robust way to test the fallback logic.
auto op = absl::make_unique<TensorFlowAssertOperator>();
@ -541,7 +586,8 @@ TEST(OperatorKeyTest, TestFlexWithPartiallySupportedOps) {
{
// If NodeDef isn't retained in the Toco op, a regular custom op
// will be exported.
const auto key = details::OperatorKey(*op, ops_by_type, true);
const toco::OperatorSignature op_signature = {op.get(), &model};
const auto key = details::OperatorKey(op_signature, ops_by_type, true);
EXPECT_EQ(key.type(), ::tflite::BuiltinOperator_CUSTOM);
EXPECT_EQ(key.custom_code(), "Assert");
EXPECT_EQ(key.version(), 1);
@ -556,7 +602,8 @@ TEST(OperatorKeyTest, TestFlexWithPartiallySupportedOps) {
{
// If NodeDef is retained in the Toco op, a Flex op will be exported.
const auto key = details::OperatorKey(*op, ops_by_type, true);
const toco::OperatorSignature op_signature = {op.get(), &model};
const auto key = details::OperatorKey(op_signature, ops_by_type, true);
EXPECT_EQ(key.type(), ::tflite::BuiltinOperator_CUSTOM);
EXPECT_EQ(key.custom_code(), "FlexAssert");
EXPECT_EQ(key.version(), 1);

View File

@ -14,19 +14,20 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/lite/toco/tflite/operator.h"
// TODO(ycling): Consider refactoring to extract the LSTM definition out of
// graph_transformation module.
#include "tensorflow/lite/toco/graph_transformations/lstm_utils.h"
#include "tensorflow/lite/toco/tflite/builtin_operator.h"
#include "tensorflow/lite/toco/tflite/custom_operator.h"
#include "tensorflow/lite/toco/tflite/simple_operator.h"
#include "tensorflow/lite/toco/tflite/types.h"
#include "tensorflow/lite/toco/tflite/whitelisted_flex_ops.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/util/ptr_util.h"
// TODO(ycling): Consider refactoring to extract the LSTM definition out of
// graph_transformation module.
#include "tensorflow/lite/toco/graph_transformations/lstm_utils.h"
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/tflite/builtin_operator.h"
#include "tensorflow/lite/toco/tflite/custom_operator.h"
#include "tensorflow/lite/toco/tflite/simple_operator.h"
#include "tensorflow/lite/toco/tflite/types.h"
#include "tensorflow/lite/toco/tflite/whitelisted_flex_ops.h"
namespace toco {
@ -60,7 +61,9 @@ class AveragePool
ActivationFunction::Deserialize(options.fused_activation_function());
}
int GetVersion(const Operator& op) const override { return 1; }
int GetVersion(const OperatorSignature& op_signature) const override {
return 1;
}
};
class Convolution
@ -92,7 +95,9 @@ class Convolution
ActivationFunction::Deserialize(options.fused_activation_function());
}
int GetVersion(const Operator& op) const override { return 1; }
int GetVersion(const OperatorSignature& op_signature) const override {
return 1;
}
};
class DepthwiseConvolution
@ -126,8 +131,9 @@ class DepthwiseConvolution
op->dilation_height_factor = options.dilation_h_factor();
}
int GetVersion(const Operator& op) const override {
const auto& conv_op = static_cast<const DepthwiseConvOperator&>(op);
int GetVersion(const OperatorSignature& op_signature) const override {
const auto& conv_op =
static_cast<const DepthwiseConvOperator&>(*op_signature.op);
if (conv_op.dilation_width_factor != 1 ||
conv_op.dilation_height_factor != 1) {
return 2;
@ -155,7 +161,9 @@ class Add : public BuiltinOperator<AddOperator, ::tflite::AddOptions,
ActivationFunction::Deserialize(options.fused_activation_function());
}
int GetVersion(const Operator& op) const override { return 1; }
int GetVersion(const OperatorSignature& op_signature) const override {
return 1;
}
};
class SpaceToBatchND
@ -174,7 +182,9 @@ class SpaceToBatchND
void ReadOptions(const TfLiteOptions& options,
TocoOperator* op) const override {}
int GetVersion(const Operator& op) const override { return 1; }
int GetVersion(const OperatorSignature& op_signature) const override {
return 1;
}
};
class Sub : public BuiltinOperator<SubOperator, ::tflite::SubOptions,
@ -196,7 +206,9 @@ class Sub : public BuiltinOperator<SubOperator, ::tflite::SubOptions,
ActivationFunction::Deserialize(options.fused_activation_function());
}
int GetVersion(const Operator& op) const override { return 1; }
int GetVersion(const OperatorSignature& op_signature) const override {
return 1;
}
};
class Div : public BuiltinOperator<DivOperator, ::tflite::DivOptions,
@ -218,7 +230,9 @@ class Div : public BuiltinOperator<DivOperator, ::tflite::DivOptions,
ActivationFunction::Deserialize(options.fused_activation_function());
}
int GetVersion(const Operator& op) const override { return 1; }
int GetVersion(const OperatorSignature& op_signature) const override {
return 1;
}
};
class BatchToSpaceND
@ -237,7 +251,9 @@ class BatchToSpaceND
void ReadOptions(const TfLiteOptions& options,
TocoOperator* op) const override {}
int GetVersion(const Operator& op) const override { return 1; }
int GetVersion(const OperatorSignature& op_signature) const override {
return 1;
}
};
class Cast : public BuiltinOperator<CastOperator, ::tflite::CastOptions,
@ -258,7 +274,9 @@ class Cast : public BuiltinOperator<CastOperator, ::tflite::CastOptions,
op->dst_data_type = DataType::Deserialize(options.out_data_type());
}
int GetVersion(const Operator& op) const override { return 1; }
int GetVersion(const OperatorSignature& op_signature) const override {
return 1;
}
};
class Concatenation
@ -278,7 +296,9 @@ class Concatenation
op->axis = options.axis();
}
int GetVersion(const Operator& op) const override { return 1; }
int GetVersion(const OperatorSignature& op_signature) const override {
return 1;
}
};
class DepthToSpace : public CustomOperator<DepthToSpaceOperator> {
@ -292,7 +312,9 @@ class DepthToSpace : public CustomOperator<DepthToSpaceOperator> {
op->block_size = m["block_size"].AsInt64();
}
int GetVersion(const Operator& op) const override { return 1; }
int GetVersion(const OperatorSignature& op_signature) const override {
return 1;
}
};
class FakeQuant
@ -315,9 +337,8 @@ class FakeQuant
op->num_bits = options.num_bits();
op->narrow_range = options.narrow_range();
}
int GetVersion(const Operator& op) const override {
const auto& fq_op = static_cast<const FakeQuantOperator&>(op);
int GetVersion(const OperatorSignature& op_signature) const override {
const auto& fq_op = static_cast<const FakeQuantOperator&>(*op_signature.op);
return fq_op.narrow_range ? 2 : 1;
}
};
@ -369,8 +390,9 @@ class FullyConnected
}
}
int GetVersion(const Operator& op) const override {
const auto& fc_op = static_cast<const FullyConnectedOperator&>(op);
int GetVersion(const OperatorSignature& op_signature) const override {
const auto& fc_op =
static_cast<const FullyConnectedOperator&>(*op_signature.op);
return fc_op.weights_format == FullyConnectedWeightsFormat::kDefault ? 1
: 2;
}
@ -392,7 +414,9 @@ class Gather : public BuiltinOperator<GatherOperator, ::tflite::GatherOptions,
op->axis = {options.axis()};
}
int GetVersion(const Operator& op) const override { return 1; }
int GetVersion(const OperatorSignature& op_signature) const override {
return 1;
}
};
class Svdf : public BuiltinOperator<SvdfOperator, ::tflite::SVDFOptions,
@ -414,7 +438,9 @@ class Svdf : public BuiltinOperator<SvdfOperator, ::tflite::SVDFOptions,
op->rank = options.rank();
}
int GetVersion(const Operator& op) const override { return 1; }
int GetVersion(const OperatorSignature& op_signature) const override {
return 1;
}
};
class L2Normalization
@ -436,7 +462,9 @@ class L2Normalization
ActivationFunction::Deserialize(options.fused_activation_function());
}
int GetVersion(const Operator& op) const override { return 1; }
int GetVersion(const OperatorSignature& op_signature) const override {
return 1;
}
};
class L2Pool : public BuiltinOperator<L2PoolOperator, ::tflite::Pool2DOptions,
@ -465,7 +493,9 @@ class L2Pool : public BuiltinOperator<L2PoolOperator, ::tflite::Pool2DOptions,
ActivationFunction::Deserialize(options.fused_activation_function());
}
int GetVersion(const Operator& op) const override { return 1; }
int GetVersion(const OperatorSignature& op_signature) const override {
return 1;
}
};
class LocalResponseNormalization
@ -490,7 +520,9 @@ class LocalResponseNormalization
op->beta = options.beta();
}
int GetVersion(const Operator& op) const override { return 1; }
int GetVersion(const OperatorSignature& op_signature) const override {
return 1;
}
};
class MaxPool : public BuiltinOperator<MaxPoolOperator, ::tflite::Pool2DOptions,
@ -519,7 +551,9 @@ class MaxPool : public BuiltinOperator<MaxPoolOperator, ::tflite::Pool2DOptions,
ActivationFunction::Deserialize(options.fused_activation_function());
}
int GetVersion(const Operator& op) const override { return 1; }
int GetVersion(const OperatorSignature& op_signature) const override {
return 1;
}
};
class Mul : public BuiltinOperator<MulOperator, ::tflite::MulOptions,
@ -541,7 +575,9 @@ class Mul : public BuiltinOperator<MulOperator, ::tflite::MulOptions,
ActivationFunction::Deserialize(options.fused_activation_function());
}
int GetVersion(const Operator& op) const override { return 1; }
int GetVersion(const OperatorSignature& op_signature) const override {
return 1;
}
};
class Pad : public BuiltinOperator<PadOperator, ::tflite::PadOptions,
@ -558,7 +594,9 @@ class Pad : public BuiltinOperator<PadOperator, ::tflite::PadOptions,
void ReadOptions(const TfLiteOptions& options,
TocoOperator* op) const override {}
int GetVersion(const Operator& op) const override { return 1; }
int GetVersion(const OperatorSignature& op_signature) const override {
return 1;
}
};
class Tile
@ -574,7 +612,9 @@ class Tile
void ReadOptions(const TfLiteOptions& options,
TocoOperator* op) const override {}
int GetVersion(const Operator& op) const override { return 1; }
int GetVersion(const OperatorSignature& op_signature) const override {
return 1;
}
};
class PadV2 : public BuiltinOperator<PadV2Operator, ::tflite::PadV2Options,
@ -591,7 +631,9 @@ class PadV2 : public BuiltinOperator<PadV2Operator, ::tflite::PadV2Options,
void ReadOptions(const TfLiteOptions& options,
TocoOperator* op) const override {}
int GetVersion(const Operator& op) const override { return 1; }
int GetVersion(const OperatorSignature& op_signature) const override {
return 1;
}
};
class Reshape
@ -614,7 +656,9 @@ class Reshape
options.new_shape()->end());
}
int GetVersion(const Operator& op) const override { return 1; }
int GetVersion(const OperatorSignature& op_signature) const override {
return 1;
}
};
class Softmax
@ -633,7 +677,9 @@ class Softmax
op->beta = options.beta();
}
int GetVersion(const Operator& op) const override { return 1; }
int GetVersion(const OperatorSignature& op_signature) const override {
return 1;
}
};
class SpaceToDepth
@ -653,7 +699,9 @@ class SpaceToDepth
op->block_size = options.block_size();
}
int GetVersion(const Operator& op) const override { return 1; }
int GetVersion(const OperatorSignature& op_signature) const override {
return 1;
}
};
class Transpose
@ -670,7 +718,9 @@ class Transpose
void ReadOptions(const TfLiteOptions& options,
TocoOperator* op) const override {}
int GetVersion(const Operator& op) const override { return 1; }
int GetVersion(const OperatorSignature& op_signature) const override {
return 1;
}
};
class Lstm : public BuiltinOperator<LstmCellOperator, ::tflite::LSTMOptions,
@ -713,8 +763,9 @@ class Lstm : public BuiltinOperator<LstmCellOperator, ::tflite::LSTMOptions,
}
}
int GetVersion(const Operator& op) const override {
const auto& lstm_op = static_cast<const LstmCellOperator&>(op);
int GetVersion(const OperatorSignature& op_signature) const override {
const auto& lstm_op =
static_cast<const LstmCellOperator&>(*op_signature.op);
switch (lstm_op.kernel_type) {
case LstmCellOperator::KERNEL_FULL:
return 1;
@ -770,7 +821,9 @@ class UnidirectionalSequenceLstm
::tflite::ActivationFunctionType_TANH);
}
int GetVersion(const Operator& op) const override { return 1; }
int GetVersion(const OperatorSignature& op_signature) const override {
return 1;
}
std::vector<bool> GetMutatingInputVariables(
const Operator& op) const override {
@ -796,7 +849,9 @@ class Mean : public BuiltinOperator<MeanOperator, ::tflite::ReducerOptions,
op->keep_dims = options.keep_dims();
}
int GetVersion(const Operator& op) const override { return 1; }
int GetVersion(const OperatorSignature& op_signature) const override {
return 1;
}
};
class Sum
@ -815,7 +870,9 @@ class Sum
op->keep_dims = options.keep_dims();
}
int GetVersion(const Operator& op) const override { return 1; }
int GetVersion(const OperatorSignature& op_signature) const override {
return 1;
}
};
class ReduceMax
@ -834,7 +891,9 @@ class ReduceMax
op->keep_dims = options.keep_dims();
}
int GetVersion(const Operator& op) const override { return 1; }
int GetVersion(const OperatorSignature& op_signature) const override {
return 1;
}
};
class ReduceMin
@ -853,7 +912,9 @@ class ReduceMin
op->keep_dims = options.keep_dims();
}
int GetVersion(const Operator& op) const override { return 1; }
int GetVersion(const OperatorSignature& op_signature) const override {
return 1;
}
};
class ReduceProd
@ -872,7 +933,9 @@ class ReduceProd
op->keep_dims = options.keep_dims();
}
int GetVersion(const Operator& op) const override { return 1; }
int GetVersion(const OperatorSignature& op_signature) const override {
return 1;
}
};
class ReduceAny
@ -891,7 +954,9 @@ class ReduceAny
op->keep_dims = options.keep_dims();
}
int GetVersion(const Operator& op) const override { return 1; }
int GetVersion(const OperatorSignature& op_signature) const override {
return 1;
}
};
class ResizeBilinear
@ -911,7 +976,9 @@ class ResizeBilinear
op->align_corners = options.align_corners();
}
int GetVersion(const Operator& op) const override { return 1; }
int GetVersion(const OperatorSignature& op_signature) const override {
return 1;
}
};
class ResizeNearestNeighbor
@ -932,7 +999,9 @@ class ResizeNearestNeighbor
op->align_corners = options.align_corners();
}
int GetVersion(const Operator& op) const override { return 1; }
int GetVersion(const OperatorSignature& op_signature) const override {
return 1;
}
};
class Squeeze
@ -955,7 +1024,9 @@ class Squeeze
options.squeeze_dims()->end());
}
int GetVersion(const Operator& op) const override { return 1; }
int GetVersion(const OperatorSignature& op_signature) const override {
return 1;
}
};
class Split
@ -975,7 +1046,9 @@ class Split
op->num_split = options.num_splits();
}
int GetVersion(const Operator& op) const override { return 1; }
int GetVersion(const OperatorSignature& op_signature) const override {
return 1;
}
};
class SplitV
@ -995,7 +1068,9 @@ class SplitV
op->num_split = options.num_splits();
}
int GetVersion(const Operator& op) const override { return 1; }
int GetVersion(const OperatorSignature& op_signature) const override {
return 1;
}
};
class StridedSlice
@ -1021,7 +1096,9 @@ class StridedSlice
op->shrink_axis_mask = options.shrink_axis_mask();
}
int GetVersion(const Operator& op) const override { return 1; }
int GetVersion(const OperatorSignature& op_signature) const override {
return 1;
}
};
class TopK_V2 : public BuiltinOperator<TopKV2Operator, ::tflite::TopKV2Options,
@ -1037,7 +1114,9 @@ class TopK_V2 : public BuiltinOperator<TopKV2Operator, ::tflite::TopKV2Options,
void ReadOptions(const TfLiteOptions& options,
TocoOperator* op) const override {}
int GetVersion(const Operator& op) const override { return 1; }
int GetVersion(const OperatorSignature& op_signature) const override {
return 1;
}
};
class ArgMax : public BuiltinOperator<ArgMaxOperator, ::tflite::ArgMaxOptions,
@ -1056,7 +1135,9 @@ class ArgMax : public BuiltinOperator<ArgMaxOperator, ::tflite::ArgMaxOptions,
op->output_data_type = DataType::Deserialize(options.output_type());
}
int GetVersion(const Operator& op) const override { return 1; }
int GetVersion(const OperatorSignature& op_signature) const override {
return 1;
}
};
class ArgMin : public BuiltinOperator<ArgMinOperator, ::tflite::ArgMinOptions,
@ -1075,7 +1156,9 @@ class ArgMin : public BuiltinOperator<ArgMinOperator, ::tflite::ArgMinOptions,
op->output_data_type = DataType::Deserialize(options.output_type());
}
int GetVersion(const Operator& op) const override { return 1; }
int GetVersion(const OperatorSignature& op_signature) const override {
return 1;
}
};
class TransposeConv
@ -1100,7 +1183,9 @@ class TransposeConv
op->stride_height = options.stride_h();
}
int GetVersion(const Operator& op) const override { return 1; }
int GetVersion(const OperatorSignature& op_signature) const override {
return 1;
}
};
class SparseToDense
@ -1121,7 +1206,9 @@ class SparseToDense
op->validate_indices = options.validate_indices();
}
int GetVersion(const Operator& op) const override { return 1; }
int GetVersion(const OperatorSignature& op_signature) const override {
return 1;
}
};
class ExpandDims
@ -1139,7 +1226,9 @@ class ExpandDims
void ReadOptions(const TfLiteOptions& options,
TocoOperator* op) const override {}
int GetVersion(const Operator& op) const override { return 1; }
int GetVersion(const OperatorSignature& op_signature) const override {
return 1;
}
};
class Pack : public BuiltinOperator<PackOperator, ::tflite::PackOptions,
@ -1159,7 +1248,9 @@ class Pack : public BuiltinOperator<PackOperator, ::tflite::PackOptions,
op->axis = options.axis();
}
int GetVersion(const Operator& op) const override { return 1; }
int GetVersion(const OperatorSignature& op_signature) const override {
return 1;
}
};
class Shape
@ -1179,7 +1270,9 @@ class Shape
op->output_data_type = DataType::Deserialize(options.out_type());
}
int GetVersion(const Operator& op) const override { return 1; }
int GetVersion(const OperatorSignature& op_signature) const override {
return 1;
}
};
class OneHot : public BuiltinOperator<OneHotOperator, ::tflite::OneHotOptions,
@ -1196,7 +1289,9 @@ class OneHot : public BuiltinOperator<OneHotOperator, ::tflite::OneHotOptions,
op->axis = options.axis();
}
int GetVersion(const Operator& op) const override { return 1; }
int GetVersion(const OperatorSignature& op_signature) const override {
return 1;
}
};
class CTCBeamSearchDecoder
@ -1217,7 +1312,9 @@ class CTCBeamSearchDecoder
op->merge_repeated = m["merge_repeated"].AsBool();
}
int GetVersion(const Operator& op) const override { return 1; }
int GetVersion(const OperatorSignature& op_signature) const override {
return 1;
}
};
class Unpack : public BuiltinOperator<UnpackOperator, ::tflite::UnpackOptions,
@ -1235,7 +1332,9 @@ class Unpack : public BuiltinOperator<UnpackOperator, ::tflite::UnpackOptions,
op->axis = options.axis();
}
int GetVersion(const Operator& op) const override { return 1; }
int GetVersion(const OperatorSignature& op_signature) const override {
return 1;
}
};
class LeakyRelu
@ -1253,7 +1352,9 @@ class LeakyRelu
op->alpha = options.alpha();
}
int GetVersion(const Operator& op) const override { return 1; }
int GetVersion(const OperatorSignature& op_signature) const override {
return 1;
}
};
class SquaredDifference
@ -1272,7 +1373,9 @@ class SquaredDifference
void ReadOptions(const TfLiteOptions& options,
TocoOperator* op) const override {}
int GetVersion(const Operator& op) const override { return 1; }
int GetVersion(const OperatorSignature& op_signature) const override {
return 1;
}
};
class MirrorPad
@ -1295,7 +1398,9 @@ class MirrorPad
: MirrorPadMode::kSymmetric;
}
int GetVersion(const Operator& op) const override { return 1; }
int GetVersion(const OperatorSignature& op_signature) const override {
return 1;
}
};
std::unique_ptr<flexbuffers::Builder> WriteFlexOpOptions(
@ -1472,8 +1577,8 @@ class TensorFlowUnsupported : public BaseOperator {
node_def.SerializeToString(&op->tensorflow_node_def);
}
int GetVersion(const Operator& op) const override {
// TODO(ycling): Deisng and implement a way to plumb the version of
int GetVersion(const OperatorSignature& op_signature) const override {
// TODO(ycling): Design and implement a way to plumb the version of
// custom ops.
return 1;
}
@ -1497,11 +1602,13 @@ class Dequantize
void ReadOptions(const TfLiteOptions& options,
TocoOperator* op) const override {}
int GetVersion(const Operator& op) const override {
// TODO(suharshs): Dequantize now supports INT8 in addition to
// QUANTIZED_UINT8. When TOCO can create models with INT8, we need
// to find a way to see the type here and return version 2. Right now
// version 2 will only be added by post training quantization tools.
int GetVersion(const OperatorSignature& op_signature) const override {
const string& input_name = op_signature.op->inputs[0];
const Array& input_array = op_signature.model->GetArray(input_name);
// Version 2 supports signed int8 input types.
if (input_array.data_type == ArrayDataType::kInt8) {
return 2;
}
return 1;
}
};
@ -1534,6 +1641,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList(
ops.push_back(MakeUnique<DepthwiseConvolution>(
::tflite::BuiltinOperator_DEPTHWISE_CONV_2D,
OperatorType::kDepthwiseConv));
ops.push_back(MakeUnique<Dequantize>(::tflite::BuiltinOperator_DEQUANTIZE,
OperatorType::kDequantize));
ops.push_back(
MakeUnique<FullyConnected>(::tflite::BuiltinOperator_FULLY_CONNECTED,
OperatorType::kFullyConnected));
@ -1645,8 +1754,6 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList(
// when custom ops are exported but SimpleOperator bypasses those. To
// prevent user confusion we are settling on using SimpleOperator only for
// builtins.
ops.push_back(MakeUnique<SimpleOperator<DequantizeOperator>>(
"DEQUANTIZE", OperatorType::kDequantize));
ops.push_back(
MakeUnique<SimpleOperator<FloorOperator>>("FLOOR", OperatorType::kFloor));
ops.push_back(

View File

@ -87,15 +87,15 @@ class BaseOperator {
const BuiltinOptions* builtin_options,
const CustomOptions* custom_options) const = 0;
// Get the op version by op parameters.
// The function need to be overridden to return the op version based on the
// Get the op version using the OperatorSignature.
// The function needs to be overridden to return the op version based on the
// parameters. Note:
// * The first version for each op should be 1 (to be consistent with the
// default value in Flatbuffer. `return 1;` is okay for newly implemented
// ops.
// * When multiple versions are defined for an op, this function need to be
// * When multiple versions are defined for an op, this function needs to be
// overridden. (See example in `operator_test.cc`)
virtual int GetVersion(const Operator& op) const = 0;
virtual int GetVersion(const OperatorSignature& op_signature) const = 0;
// Given a Toco `Operator`, return a list of booleans indicating the op
// mutates which input variables.

View File

@ -111,8 +111,6 @@ class OperatorTest : public ::testing::Test {
};
TEST_F(OperatorTest, SimpleOperators) {
CheckSimpleOperator<DequantizeOperator>("DEQUANTIZE",
OperatorType::kDequantize);
CheckSimpleOperator<FloorOperator>("FLOOR", OperatorType::kFloor);
CheckSimpleOperator<ReluOperator>("RELU", OperatorType::kRelu);
CheckSimpleOperator<Relu1Operator>("RELU_N1_TO_1", OperatorType::kRelu1);
@ -469,6 +467,12 @@ TEST_F(OperatorTest, BuiltinArgMin) {
EXPECT_EQ(op.output_data_type, output_toco_op->output_data_type);
}
TEST_F(OperatorTest, BuiltinDequantize) {
DequantizeOperator op;
auto output_toco_op = SerializeAndDeserialize(
GetOperator("DEQUANTIZE", OperatorType::kDequantize), op);
}
TEST_F(OperatorTest, BuiltinTransposeConv) {
TransposeConvOperator op;
op.stride_width = 123;

View File

@ -42,7 +42,9 @@ class SimpleOperator : public BaseOperator {
return std::unique_ptr<Operator>(new T);
}
int GetVersion(const Operator& op) const override { return 1; }
int GetVersion(const OperatorSignature& op_signature) const override {
return 1;
}
};
} // namespace tflite