Jared Duke 6be604aaac Reland (Attempt #3) PR #35985: [TFLite int16] 16-bit version of ADD/SUB reference kernel operators
Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/35985

This PR is one of the steps to extend 8-bit quantization to support symmetric 16-bit activations.

Each activation is of type int16 and symmetric around zero. The weight tensor precision remains at 8-bit signed values. The bias is set to int64 precision.

In this PR we introduce the implementation and tests for the ADD/SUB reference kernel functions.
The specification of this operator:

SUB
  Input 0:
    data_type  : int16
    range      : [-32768, 32767]
    granularity: per-tensor, zero_point=0
  Input 1:
    data_type  : int16
    range      : [-32768, 32767]
    granularity: per-tensor, zero_point=0
  Output 0:
    data_type  : int16
    range      : [-32768, 32767]
    granularity: per-tensor, zero_point=0

ADD
  Input 0:
    data_type  : int16
    range      : [-32768, 32767]
    granularity: per-tensor, zero_point=0
  Input 1:
    data_type  : int16
    range      : [-32768, 32767]
    granularity: per-tensor, zero_point=0
  Output 0:
    data_type  : int16
    range      : [-32768, 32767]
    granularity: per-tensor, zero_point=0
Copybara import of the project:

--
b94cb4732ab536828e565fd1c7b557f124432e29 by Elena Zhelezina <elena.zhelezina@arm.com>:

Added 16-bit version of ADD/SUB operators. Broadcasting is included.

--
924d0b72c568f249f2fd224a942f8922524bfede by Elena Zhelezina <elena.zhelezina@arm.com>:

Addressed reviewer comments.

--
dd0d9e8f03d1fb1b887609fffb8ea5a86638c63e by Elena Zhelezina <elena.zhelezina@arm.com>:

Added versioning to ADD/SUB + some rework of the existing code.

--
abae3fd9a9b894c07d13c9ef416092c9004bc913 by Elena Zhelezina <elena.zhelezina@arm.com>:

Added versioning for ADD/SUB with new option in the schema.fbs
schema_generated.h is edited manually.

--
24f3f5593a06d24fa1ca6be257f1265b5293d492 by Elena Zhelezina <elena.zhelezina@arm.com>:

Fix for broken build.

--
d252fe175aef3a1a08c65155815efb706aa80afd by Elena Zhelezina <elena.zhelezina@arm.com>:

Fix for the failing internal test for NN delegates.

--
2223a5c380bb821eb05f8034703c687269353e32 by Elena Zhelezina <elena.zhelezina@arm.com>:

Fix for asan failures.

Change-Id: I2cf421ddda7f9e802202239136ab062bcd63b4aa

--
3c219a46ce5888e8e402b64cc943ac6522156ef5 by Elena Zhelezina <elena.zhelezina@arm.com>:

Added broadcast params to addsub structure.

Change-Id: I61d7d4a94087d052a782890799211031f6ed3015

--
9131a38c776109cdbcfa60be602667ec7aafe00f by Elena Zhelezina <elena.zhelezina@arm.com>:

Corrected defaults.

Change-Id: I9ea50c75014cc03ac91fdef0f5b4fe11395f7074
PiperOrigin-RevId: 324865496
2020-08-04 12:31:25 -07:00

2142 lines
84 KiB
C++

/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/toco/tflite/operator.h"
#include <map>
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/util/ptr_util.h"
// TODO(ycling): Consider refactoring to extract the LSTM definition out of
// graph_transformation module.
#include "tensorflow/lite/delegates/flex/allowlisted_flex_ops.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/toco/graph_transformations/lstm_utils.h"
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/tflite/builtin_operator.h"
#include "tensorflow/lite/toco/tflite/custom_operator.h"
#include "tensorflow/lite/toco/tflite/simple_operator.h"
#include "tensorflow/lite/toco/tflite/types.h"
#include "tensorflow/lite/tools/versioning/op_version.h"
namespace toco {
namespace tflite {
// LINT.IfChange
// Maps a toco ArrayDataType to the matching ::tflite::TensorType.
//
// Data types without a case below yield static_cast<::tflite::TensorType>(-1),
// the same sentinel GetVersioningOpSig uses for arrays missing from the model.
::tflite::TensorType GetTensorType(const ArrayDataType type) {
  // A switch is used instead of a lookup table so that no std::map has to be
  // built and torn down on every call.
  switch (type) {
    case ArrayDataType::kBool:
      return ::tflite::TensorType_BOOL;
    case ArrayDataType::kFloat:
      return ::tflite::TensorType_FLOAT32;
    case ArrayDataType::kInt8:
      return ::tflite::TensorType_INT8;
    case ArrayDataType::kUint8:
      return ::tflite::TensorType_UINT8;
    case ArrayDataType::kInt16:
      return ::tflite::TensorType_INT16;
    case ArrayDataType::kInt32:
      return ::tflite::TensorType_INT32;
    case ArrayDataType::kInt64:
      return ::tflite::TensorType_INT64;
    case ArrayDataType::kString:
      return ::tflite::TensorType_STRING;
    case ArrayDataType::kComplex64:
      return ::tflite::TensorType_COMPLEX64;
    case ArrayDataType::kComplex128:
      return ::tflite::TensorType_COMPLEX128;
    case ArrayDataType::kFloat16:
      return ::tflite::TensorType_FLOAT16;
    case ArrayDataType::kFloat64:
      return ::tflite::TensorType_FLOAT64;
    default:
      // No TFLite equivalent is listed for this data type.
      return static_cast<::tflite::TensorType>(-1);
  }
}
// Builds the ::tflite::OpSignature that the versioning code consumes.
//
// The signature carries `op` plus one tensor type per input and per output of
// `op_signature.op`, in order. A name with no array in `op_signature.model` is
// recorded as static_cast<::tflite::TensorType>(-1).
::tflite::OpSignature GetVersioningOpSig(
    const ::tflite::BuiltinOperator op, const OperatorSignature& op_signature) {
  // Resolves a single array name to its tensor type, or to the -1 sentinel
  // when the model does not contain that array.
  const auto type_of =
      [&op_signature](const std::string& array_name) -> ::tflite::TensorType {
    if (!op_signature.model->HasArray(array_name)) {
      return static_cast<::tflite::TensorType>(-1);
    }
    const Array& array = op_signature.model->GetArray(array_name);
    return GetTensorType(array.data_type);
  };

  std::vector<::tflite::TensorType> input_types;
  for (const auto& name : op_signature.op->inputs) {
    input_types.push_back(type_of(name));
  }

  std::vector<::tflite::TensorType> output_types;
  for (const auto& name : op_signature.op->outputs) {
    output_types.push_back(type_of(name));
  }

  return ::tflite::OpSignature{op, input_types, output_types};
}
// Converts AveragePoolOperator to and from the TFLite Pool2DOptions table.
class AveragePool
    : public BuiltinOperator<AveragePoolOperator, ::tflite::Pool2DOptions,
                             ::tflite::BuiltinOptions_Pool2DOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Writes padding, strides, kernel size and fused activation of `op` into a
  // Pool2DOptions table created with `builder`.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto serialized_padding = Padding::Serialize(op.padding.type);
    const auto serialized_activation =
        ActivationFunction::Serialize(op.fused_activation_function);
    return ::tflite::CreatePool2DOptions(
        *builder, serialized_padding, op.stride_width, op.stride_height,
        op.kwidth, op.kheight, serialized_activation);
  }

  // Copies every Pool2DOptions field read here back onto `op`.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    op->fused_activation_function =
        ActivationFunction::Deserialize(options.fused_activation_function());
    op->padding.type = Padding::Deserialize(options.padding());
    op->kwidth = options.filter_width();
    op->kheight = options.filter_height();
    op->stride_width = options.stride_w();
    op->stride_height = options.stride_h();
  }
};
// Converts ConvOperator to and from the TFLite Conv2DOptions table.
class Convolution
    : public BuiltinOperator<ConvOperator, ::tflite::Conv2DOptions,
                             ::tflite::BuiltinOptions_Conv2DOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Writes padding, strides, fused activation and dilation factors of `op`.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto serialized_padding = Padding::Serialize(op.padding.type);
    const auto serialized_activation =
        ActivationFunction::Serialize(op.fused_activation_function);
    return ::tflite::CreateConv2DOptions(
        *builder, serialized_padding, op.stride_width, op.stride_height,
        serialized_activation, op.dilation_width_factor,
        op.dilation_height_factor);
  }

  // Restores the fields written by WriteOptions onto `op`.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    op->fused_activation_function =
        ActivationFunction::Deserialize(options.fused_activation_function());
    op->padding.type = Padding::Deserialize(options.padding());
    op->dilation_width_factor = options.dilation_w_factor();
    op->dilation_height_factor = options.dilation_h_factor();
    op->stride_width = options.stride_w();
    op->stride_height = options.stride_h();
  }
};
// Converts DepthwiseConvOperator to and from the TFLite
// DepthwiseConv2DOptions table, and picks the operator version.
class DepthwiseConvolution
    : public BuiltinOperator<DepthwiseConvOperator,
                             ::tflite::DepthwiseConv2DOptions,
                             ::tflite::BuiltinOptions_DepthwiseConv2DOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Writes padding, strides, depth multiplier, fused activation and dilation
  // factors of `op`.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto serialized_padding = Padding::Serialize(op.padding.type);
    const auto serialized_activation =
        ActivationFunction::Serialize(op.fused_activation_function);
    return ::tflite::CreateDepthwiseConv2DOptions(
        *builder, serialized_padding, op.stride_width, op.stride_height,
        op.depth_multiplier, serialized_activation, op.dilation_width_factor,
        op.dilation_height_factor);
  }

  // Restores the fields written by WriteOptions onto `op`.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    op->fused_activation_function =
        ActivationFunction::Deserialize(options.fused_activation_function());
    op->padding.type = Padding::Deserialize(options.padding());
    op->depth_multiplier = options.depth_multiplier();
    op->stride_width = options.stride_w();
    op->stride_height = options.stride_h();
    op->dilation_width_factor = options.dilation_w_factor();
    op->dilation_height_factor = options.dilation_h_factor();
  }

  // Computes the version from the tensor types plus the dilation factors,
  // which are forwarded to ::tflite::GetBuiltinOperatorVersion.
  int GetVersion(const OperatorSignature& op_signature) const override {
    ::tflite::OpSignature op_sig =
        GetVersioningOpSig(builtin_op(), op_signature);
    const auto& depthwise_op =
        static_cast<const DepthwiseConvOperator&>(*op_signature.op);
    auto& versioning_options = op_sig.options.depthwise_conv_2d;
    versioning_options.dilation_w_factor = depthwise_op.dilation_width_factor;
    versioning_options.dilation_h_factor = depthwise_op.dilation_height_factor;
    return ::tflite::GetBuiltinOperatorVersion(op_sig);
  }
};
// Converts AddOperator to and from the TFLite AddOptions table.
class Add : public BuiltinOperator<AddOperator, ::tflite::AddOptions,
                                   ::tflite::BuiltinOptions_AddOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Writes the fused activation function of `op`.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto serialized_activation =
        ActivationFunction::Serialize(op.fused_activation_function);
    return ::tflite::CreateAddOptions(*builder, serialized_activation);
  }

  // Restores the fused activation function onto `op`.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    const auto serialized_activation = options.fused_activation_function();
    op->fused_activation_function =
        ActivationFunction::Deserialize(serialized_activation);
  }
};
// Converts AddNOperator to and from the TFLite AddNOptions table. No option
// fields are written or read.
class AddN : public BuiltinOperator<AddNOperator, ::tflite::AddNOptions,
                                    ::tflite::BuiltinOptions_AddNOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Emits an AddNOptions table with no fields set from `op`.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    return ::tflite::CreateAddNOptions(*builder);
  }

  // Intentionally empty: nothing is copied from `options` to `op`.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {}
};
// Converts SpaceToBatchNDOperator to and from the TFLite
// SpaceToBatchNDOptions table, and picks the operator version.
class SpaceToBatchND
    : public BuiltinOperator<SpaceToBatchNDOperator,
                             ::tflite::SpaceToBatchNDOptions,
                             ::tflite::BuiltinOptions_SpaceToBatchNDOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Emits a SpaceToBatchNDOptions table with no fields set from `op`.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    return ::tflite::CreateSpaceToBatchNDOptions(*builder);
  }

  // Intentionally empty: nothing is copied from `options` to `op`.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {}

  // Computes the version from the tensor types plus the rank of the first
  // input.
  // NOTE(review): the shape is read without the has_shape() guard that Sub and
  // Div use in this file -- confirm the first input always has a shape here.
  int GetVersion(const OperatorSignature& op_signature) const override {
    const Array& first_input =
        op_signature.model->GetArray(op_signature.op->inputs[0]);
    ::tflite::OpSignature op_sig =
        GetVersioningOpSig(builtin_op(), op_signature);
    const auto rank = first_input.shape().dimensions_count();
    op_sig.options.single_input_op.num_dims = rank;
    return ::tflite::GetBuiltinOperatorVersion(op_sig);
  }
};
// Converts SubOperator to and from the TFLite SubOptions table, and picks the
// operator version.
class Sub : public BuiltinOperator<SubOperator, ::tflite::SubOptions,
                                   ::tflite::BuiltinOptions_SubOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Writes the fused activation function of `op`.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto serialized_activation =
        ActivationFunction::Serialize(op.fused_activation_function);
    return ::tflite::CreateSubOptions(*builder, serialized_activation);
  }

  // Restores the fused activation function onto `op`.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    const auto serialized_activation = options.fused_activation_function();
    op->fused_activation_function =
        ActivationFunction::Deserialize(serialized_activation);
  }

  // Computes the version from the tensor types and, when both inputs have a
  // shape, from the larger input rank and whether the two shapes differ.
  int GetVersion(const OperatorSignature& op_signature) const override {
    const std::string& lhs_name = op_signature.op->inputs[0];
    const std::string& rhs_name = op_signature.op->inputs[1];
    const Array& lhs = op_signature.model->GetArray(lhs_name);
    const Array& rhs = op_signature.model->GetArray(rhs_name);
    ::tflite::OpSignature op_sig =
        GetVersioningOpSig(builtin_op(), op_signature);

    // Without both shapes the addsub options are left as GetVersioningOpSig
    // produced them.
    if (!lhs.has_shape() || !rhs.has_shape()) {
      return ::tflite::GetBuiltinOperatorVersion(op_sig);
    }

    auto& versioning_options = op_sig.options.addsub;
    versioning_options.num_dims = std::max(lhs.shape().dimensions_count(),
                                           rhs.shape().dimensions_count());
    versioning_options.need_broadcast = (lhs.shape() != rhs.shape());
    return ::tflite::GetBuiltinOperatorVersion(op_sig);
  }
};
// Converts DivOperator to and from the TFLite DivOptions table, and picks the
// operator version.
class Div : public BuiltinOperator<DivOperator, ::tflite::DivOptions,
                                   ::tflite::BuiltinOptions_DivOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Writes the fused activation function of `op`.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto serialized_activation =
        ActivationFunction::Serialize(op.fused_activation_function);
    return ::tflite::CreateDivOptions(*builder, serialized_activation);
  }

  // Restores the fused activation function onto `op`.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    const auto serialized_activation = options.fused_activation_function();
    op->fused_activation_function =
        ActivationFunction::Deserialize(serialized_activation);
  }

  // Computes the version from the tensor types and, when both inputs have a
  // shape, from the larger input rank and whether the two shapes differ.
  int GetVersion(const OperatorSignature& op_signature) const override {
    const std::string& lhs_name = op_signature.op->inputs[0];
    const std::string& rhs_name = op_signature.op->inputs[1];
    const Array& lhs = op_signature.model->GetArray(lhs_name);
    const Array& rhs = op_signature.model->GetArray(rhs_name);
    ::tflite::OpSignature op_sig =
        GetVersioningOpSig(builtin_op(), op_signature);

    // Without both shapes the broadcast options are left as
    // GetVersioningOpSig produced them.
    if (!lhs.has_shape() || !rhs.has_shape()) {
      return ::tflite::GetBuiltinOperatorVersion(op_sig);
    }

    auto& versioning_options = op_sig.options.broadcast;
    versioning_options.num_dims = std::max(lhs.shape().dimensions_count(),
                                           rhs.shape().dimensions_count());
    versioning_options.need_broadcast = (lhs.shape() != rhs.shape());
    return ::tflite::GetBuiltinOperatorVersion(op_sig);
  }
};
// Converts BatchToSpaceNDOperator to and from the TFLite
// BatchToSpaceNDOptions table, and picks the operator version.
class BatchToSpaceND
    : public BuiltinOperator<BatchToSpaceNDOperator,
                             ::tflite::BatchToSpaceNDOptions,
                             ::tflite::BuiltinOptions_BatchToSpaceNDOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Emits a BatchToSpaceNDOptions table with no fields set from `op`.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    return ::tflite::CreateBatchToSpaceNDOptions(*builder);
  }

  // Intentionally empty: nothing is copied from `options` to `op`.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {}

  // Computes the version from the tensor types plus the rank of the first
  // input.
  // NOTE(review): the shape is read without the has_shape() guard that Sub and
  // Div use in this file -- confirm the first input always has a shape here.
  int GetVersion(const OperatorSignature& op_signature) const override {
    const Array& first_input =
        op_signature.model->GetArray(op_signature.op->inputs[0]);
    ::tflite::OpSignature op_sig =
        GetVersioningOpSig(builtin_op(), op_signature);
    const auto rank = first_input.shape().dimensions_count();
    op_sig.options.single_input_op.num_dims = rank;
    return ::tflite::GetBuiltinOperatorVersion(op_sig);
  }
};
// Converts CastOperator to and from the TFLite CastOptions table.
class Cast : public BuiltinOperator<CastOperator, ::tflite::CastOptions,
                                    ::tflite::BuiltinOptions_CastOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Writes the source and destination data types of `op`.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto serialized_src = DataType::Serialize(op.src_data_type);
    const auto serialized_dst = DataType::Serialize(op.dst_data_type);
    return ::tflite::CreateCastOptions(*builder, serialized_src,
                                       serialized_dst);
  }

  // Restores the source and destination data types onto `op`.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    op->dst_data_type = DataType::Deserialize(options.out_data_type());
    op->src_data_type = DataType::Deserialize(options.in_data_type());
  }
};
// Converts ConcatenationOperator to and from the TFLite ConcatenationOptions
// table. Only the axis is transferred.
class Concatenation
    : public BuiltinOperator<ConcatenationOperator,
                             ::tflite::ConcatenationOptions,
                             ::tflite::BuiltinOptions_ConcatenationOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Writes `op.axis` into a ConcatenationOptions table.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    flatbuffers::FlatBufferBuilder& fbb = *builder;
    return ::tflite::CreateConcatenationOptions(fbb, op.axis);
  }

  // Restores the axis onto `op`.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    op->axis = options.axis();
  }
};
// Converts DepthToSpaceOperator to and from the TFLite DepthToSpaceOptions
// table. Only the block size is transferred.
class DepthToSpace
    : public BuiltinOperator<DepthToSpaceOperator,
                             ::tflite::DepthToSpaceOptions,
                             ::tflite::BuiltinOptions_DepthToSpaceOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Writes `op.block_size` into a DepthToSpaceOptions table.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    flatbuffers::FlatBufferBuilder& fbb = *builder;
    return ::tflite::CreateDepthToSpaceOptions(fbb, op.block_size);
  }

  // Restores the block size onto `op`.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    op->block_size = options.block_size();
  }
};
// Converts FakeQuantOperator to and from the TFLite FakeQuantOptions table,
// and picks the operator version.
class FakeQuant
    : public BuiltinOperator<FakeQuantOperator, ::tflite::FakeQuantOptions,
                             ::tflite::BuiltinOptions_FakeQuantOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Writes min, max, bit count and the narrow-range flag of `op`.
  // `op.minmax` is dereferenced without a null check.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto& range = *op.minmax;
    return ::tflite::CreateFakeQuantOptions(*builder, range.min, range.max,
                                            op.num_bits, op.narrow_range);
  }

  // Replaces `op->minmax` with a fresh MinMax filled from `options`, then
  // restores the bit count and narrow-range flag.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    op->minmax.reset(new MinMax);
    op->minmax->min = options.min();
    op->minmax->max = options.max();
    op->num_bits = options.num_bits();
    op->narrow_range = options.narrow_range();
  }

  // Computes the version from the tensor types plus the narrow-range flag.
  int GetVersion(const OperatorSignature& op_signature) const override {
    ::tflite::OpSignature op_sig =
        GetVersioningOpSig(builtin_op(), op_signature);
    const auto& fake_quant_op =
        static_cast<const FakeQuantOperator&>(*op_signature.op);
    op_sig.options.fakequant.narrow_range = fake_quant_op.narrow_range;
    return ::tflite::GetBuiltinOperatorVersion(op_sig);
  }
};
// Converts FullyConnectedOperator to and from the TFLite
// FullyConnectedOptions table, and picks the operator version.
class FullyConnected
    : public BuiltinOperator<FullyConnectedOperator,
                             ::tflite::FullyConnectedOptions,
                             ::tflite::BuiltinOptions_FullyConnectedOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Translates a toco weights format into its TFLite counterpart. Any format
  // other than the two handled here is logged and mapped to DEFAULT.
  ::tflite::FullyConnectedOptionsWeightsFormat GetWeightFormat(
      FullyConnectedWeightsFormat fmt) const {
    if (fmt == FullyConnectedWeightsFormat::kDefault) {
      return ::tflite::FullyConnectedOptionsWeightsFormat_DEFAULT;
    }
    if (fmt == FullyConnectedWeightsFormat::kShuffled4x16Int8) {
      return ::tflite::FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8;
    }
    LOG(ERROR) << "Unhandled FC weights format";
    return ::tflite::FullyConnectedOptionsWeightsFormat_DEFAULT;
  }

  // Writes the fused activation function and weights format of `op`.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto serialized_activation =
        ActivationFunction::Serialize(op.fused_activation_function);
    return ::tflite::CreateFullyConnectedOptions(
        *builder, serialized_activation, GetWeightFormat(op.weights_format));
  }

  // Restores the fused activation function and weights format onto `op`. An
  // unrecognized serialized format is logged and read back as kDefault.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    op->fused_activation_function =
        ActivationFunction::Deserialize(options.fused_activation_function());

    const auto serialized_format = options.weights_format();
    if (serialized_format ==
        ::tflite::FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8) {
      op->weights_format = FullyConnectedWeightsFormat::kShuffled4x16Int8;
      return;
    }
    if (serialized_format !=
        ::tflite::FullyConnectedOptionsWeightsFormat_DEFAULT) {
      LOG(ERROR) << "Unhandled FC weights format";
    }
    op->weights_format = FullyConnectedWeightsFormat::kDefault;
  }

  // Computes the version from the tensor types plus keep_num_dims and the
  // weights format; sparse_weight is always reported as false.
  int GetVersion(const OperatorSignature& op_signature) const override {
    ::tflite::OpSignature op_sig =
        GetVersioningOpSig(builtin_op(), op_signature);
    const auto& fully_connected_op =
        static_cast<const FullyConnectedOperator&>(*op_signature.op);
    auto& versioning_options = op_sig.options.fully_connected;
    versioning_options.keep_num_dims = fully_connected_op.keep_num_dims;
    versioning_options.weights_format =
        GetWeightFormat(fully_connected_op.weights_format);
    versioning_options.sparse_weight = false;
    return ::tflite::GetBuiltinOperatorVersion(op_sig);
  }
};
// Converts GatherOperator to and from the TFLite GatherOptions table.
class Gather : public BuiltinOperator<GatherOperator, ::tflite::GatherOptions,
                                      ::tflite::BuiltinOptions_GatherOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Writes the gather axis, using 0 when `op.axis` holds no value.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    int axis = 0;
    if (op.axis) {
      axis = op.axis.value();
    }
    return ::tflite::CreateGatherOptions(*builder, axis);
  }

  // Restores the axis onto `op`.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    op->axis = {options.axis()};
  }
};
// Converts GatherNdOperator to and from the TFLite GatherNdOptions table. No
// option fields are written or read.
class GatherNd
    : public BuiltinOperator<GatherNdOperator, ::tflite::GatherNdOptions,
                             ::tflite::BuiltinOptions_GatherNdOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Emits a GatherNdOptions table with no fields set from `op`.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    return ::tflite::CreateGatherNdOptions(*builder);
  }

  // Intentionally empty: nothing is copied from `options` to `op`.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {}
};
// Converts SvdfOperator to and from the TFLite SVDFOptions table.
class Svdf : public BuiltinOperator<SvdfOperator, ::tflite::SVDFOptions,
                                    ::tflite::BuiltinOptions_SVDFOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Writes the rank and fused activation function of `op`.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto serialized_activation =
        ActivationFunction::Serialize(op.fused_activation_function);
    return ::tflite::CreateSVDFOptions(*builder, op.rank,
                                       serialized_activation);
  }

  // Restores the rank and fused activation function onto `op`.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    op->rank = options.rank();
    op->fused_activation_function =
        ActivationFunction::Deserialize(options.fused_activation_function());
  }
};
// Converts L2NormalizationOperator to and from the TFLite L2NormOptions table.
class L2Normalization
    : public BuiltinOperator<L2NormalizationOperator, ::tflite::L2NormOptions,
                             ::tflite::BuiltinOptions_L2NormOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Writes the fused activation function of `op`.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto serialized_activation =
        ActivationFunction::Serialize(op.fused_activation_function);
    return ::tflite::CreateL2NormOptions(*builder, serialized_activation);
  }

  // Restores the fused activation function onto `op`.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    const auto serialized_activation = options.fused_activation_function();
    op->fused_activation_function =
        ActivationFunction::Deserialize(serialized_activation);
  }
};
// Converts L2PoolOperator to and from the TFLite Pool2DOptions table.
class L2Pool : public BuiltinOperator<L2PoolOperator, ::tflite::Pool2DOptions,
                                      ::tflite::BuiltinOptions_Pool2DOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Writes padding, strides, kernel size and fused activation of `op` into a
  // Pool2DOptions table created with `builder`.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto serialized_padding = Padding::Serialize(op.padding.type);
    const auto serialized_activation =
        ActivationFunction::Serialize(op.fused_activation_function);
    return ::tflite::CreatePool2DOptions(
        *builder, serialized_padding, op.stride_width, op.stride_height,
        op.kwidth, op.kheight, serialized_activation);
  }

  // Copies every Pool2DOptions field read here back onto `op`.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    op->fused_activation_function =
        ActivationFunction::Deserialize(options.fused_activation_function());
    op->padding.type = Padding::Deserialize(options.padding());
    op->kwidth = options.filter_width();
    op->kheight = options.filter_height();
    op->stride_width = options.stride_w();
    op->stride_height = options.stride_h();
  }
};
// Converts LocalResponseNormalizationOperator to and from the TFLite
// LocalResponseNormalizationOptions table.
class LocalResponseNormalization
    : public BuiltinOperator<
          LocalResponseNormalizationOperator,
          ::tflite::LocalResponseNormalizationOptions,
          ::tflite::BuiltinOptions_LocalResponseNormalizationOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Writes range, bias, alpha and beta of `op`, in that argument order.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    flatbuffers::FlatBufferBuilder& fbb = *builder;
    return ::tflite::CreateLocalResponseNormalizationOptions(
        fbb, op.range, op.bias, op.alpha, op.beta);
  }

  // Restores the four parameters onto `op`; the serialized `radius` field is
  // stored in `op->range`.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    op->alpha = options.alpha();
    op->beta = options.beta();
    op->bias = options.bias();
    op->range = options.radius();
  }
};
// Converts MaxPoolOperator to and from the TFLite Pool2DOptions table.
class MaxPool : public BuiltinOperator<MaxPoolOperator, ::tflite::Pool2DOptions,
                                       ::tflite::BuiltinOptions_Pool2DOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Writes padding, strides, kernel size and fused activation of `op` into a
  // Pool2DOptions table created with `builder`.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto serialized_padding = Padding::Serialize(op.padding.type);
    const auto serialized_activation =
        ActivationFunction::Serialize(op.fused_activation_function);
    return ::tflite::CreatePool2DOptions(
        *builder, serialized_padding, op.stride_width, op.stride_height,
        op.kwidth, op.kheight, serialized_activation);
  }

  // Copies every Pool2DOptions field read here back onto `op`.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    op->fused_activation_function =
        ActivationFunction::Deserialize(options.fused_activation_function());
    op->padding.type = Padding::Deserialize(options.padding());
    op->kwidth = options.filter_width();
    op->kheight = options.filter_height();
    op->stride_width = options.stride_w();
    op->stride_height = options.stride_h();
  }
};
// Converts MulOperator to and from the TFLite MulOptions table, and picks the
// operator version.
class Mul : public BuiltinOperator<MulOperator, ::tflite::MulOptions,
                                   ::tflite::BuiltinOptions_MulOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Writes the fused activation function of `op`.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto serialized_activation =
        ActivationFunction::Serialize(op.fused_activation_function);
    return ::tflite::CreateMulOptions(*builder, serialized_activation);
  }

  // Restores the fused activation function onto `op`.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    const auto serialized_activation = options.fused_activation_function();
    op->fused_activation_function =
        ActivationFunction::Deserialize(serialized_activation);
  }

  // Computes the version from the tensor types plus the quantization scales of
  // both inputs and the output. An array without quantization params
  // contributes a scale of 0.
  int GetVersion(const OperatorSignature& op_signature) const override {
    const std::string& lhs_name = op_signature.op->inputs[0];
    const std::string& rhs_name = op_signature.op->inputs[1];
    const std::string& result_name = op_signature.op->outputs[0];
    const Array& lhs = op_signature.model->GetArray(lhs_name);
    const Array& rhs = op_signature.model->GetArray(rhs_name);
    const Array& result = op_signature.model->GetArray(result_name);

    // Reads the scale of `array`, or 0 when it has no quantization params.
    const auto scale_of = [](const Array& array) -> float {
      const auto& quant = array.quantization_params;
      return quant ? quant->scale : 0.0f;
    };
    const float lhs_scale = scale_of(lhs);
    const float rhs_scale = scale_of(rhs);
    const float result_scale = scale_of(result);

    ::tflite::OpSignature op_sig =
        GetVersioningOpSig(builtin_op(), op_signature);
    auto& versioning_options = op_sig.options.mul;
    versioning_options.input1_scale = lhs_scale;
    versioning_options.input2_scale = rhs_scale;
    versioning_options.output_scale = result_scale;
    return ::tflite::GetBuiltinOperatorVersion(op_sig);
  }
};
// Converts PadOperator to and from the TFLite PadOptions table. No option
// fields are written or read.
class Pad : public BuiltinOperator<PadOperator, ::tflite::PadOptions,
                                   ::tflite::BuiltinOptions_PadOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Emits a PadOptions table with no fields set from `op`.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    return ::tflite::CreatePadOptions(*builder);
  }

  // Intentionally empty: nothing is copied from `options` to `op`.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {}
};
// Converts TensorFlowTileOperator to and from the TFLite TileOptions table.
// No option fields are written or read.
class Tile
    : public BuiltinOperator<TensorFlowTileOperator, ::tflite::TileOptions,
                             ::tflite::BuiltinOptions_TileOptions> {
 public:
  // The members below are explicitly public, matching the other operator
  // classes in this file; without the specifier they defaulted to private.
  using BuiltinOperator::BuiltinOperator;

  // Emits a TileOptions table with no fields set from `op`.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    return ::tflite::CreateTileOptions(*builder);
  }

  // Intentionally empty: nothing is copied from `options` to `op`.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {}
};
// Converts PadV2Operator to and from the TFLite PadV2Options table. No option
// fields are written or read.
class PadV2 : public BuiltinOperator<PadV2Operator, ::tflite::PadV2Options,
                                     ::tflite::BuiltinOptions_PadV2Options> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Emits a PadV2Options table with no fields set from `op`.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    return ::tflite::CreatePadV2Options(*builder);
  }

  // Intentionally empty: nothing is copied from `options` to `op`.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {}
};
// Converts TensorFlowReshapeOperator to and from the TFLite ReshapeOptions
// table. Only the target shape is transferred.
class Reshape
    : public BuiltinOperator<TensorFlowReshapeOperator,
                             ::tflite::ReshapeOptions,
                             ::tflite::BuiltinOptions_ReshapeOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Writes `op.shape` as the new_shape vector.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    return ::tflite::CreateReshapeOptions(*builder,
                                          builder->CreateVector(op.shape));
  }

  // Appends the serialized new_shape entries to `op->shape`. When the options
  // table carries no new_shape vector, `op->shape` is left untouched.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    // FlatBuffers vector accessors return nullptr for an absent field, so the
    // pointer must be checked before it is dereferenced.
    const auto* new_shape = options.new_shape();
    if (new_shape == nullptr) {
      return;
    }
    op->shape.insert(op->shape.end(), new_shape->begin(), new_shape->end());
  }
};
// Converts SoftmaxOperator to and from the TFLite SoftmaxOptions table. Only
// beta is transferred.
class Softmax
    : public BuiltinOperator<SoftmaxOperator, ::tflite::SoftmaxOptions,
                             ::tflite::BuiltinOptions_SoftmaxOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Writes `op.beta` into a SoftmaxOptions table.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    flatbuffers::FlatBufferBuilder& fbb = *builder;
    return ::tflite::CreateSoftmaxOptions(fbb, op.beta);
  }

  // Restores beta onto `op`.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    op->beta = options.beta();
  }
};
// Converts SpaceToDepthOperator to and from the TFLite SpaceToDepthOptions
// table. Only the block size is transferred.
class SpaceToDepth
    : public BuiltinOperator<SpaceToDepthOperator,
                             ::tflite::SpaceToDepthOptions,
                             ::tflite::BuiltinOptions_SpaceToDepthOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Writes `op.block_size` into a SpaceToDepthOptions table.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    flatbuffers::FlatBufferBuilder& fbb = *builder;
    return ::tflite::CreateSpaceToDepthOptions(fbb, op.block_size);
  }

  // Restores the block size onto `op`.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    op->block_size = options.block_size();
  }
};
// Converts TransposeOperator to and from the TFLite TransposeOptions table.
// No option fields are written or read.
class Transpose
    : public BuiltinOperator<TransposeOperator, ::tflite::TransposeOptions,
                             ::tflite::BuiltinOptions_TransposeOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Emits a TransposeOptions table with no fields set from `op`.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    return ::tflite::CreateTransposeOptions(*builder);
  }

  // Intentionally empty: nothing is copied from `options` to `op`.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {}
};
// Converts LstmCellOperator to and from the TFLite LSTMOptions table, picks
// the operator version, and reports which inputs the kernel mutates.
class Lstm : public BuiltinOperator<LstmCellOperator, ::tflite::LSTMOptions,
                                    ::tflite::BuiltinOptions_LSTMOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Translates a toco LSTM kernel type into its TFLite counterpart. Any other
  // value is logged and mapped to static_cast<::tflite::LSTMKernelType>(-1).
  ::tflite::LSTMKernelType GetKernelType(
      LstmCellOperator::KernelType type) const {
    switch (type) {
      case LstmCellOperator::KERNEL_BASIC:
        return ::tflite::LSTMKernelType_BASIC;
      case LstmCellOperator::KERNEL_FULL:
        return ::tflite::LSTMKernelType_FULL;
      default:
        LOG(ERROR) << "Unhandled Kernel Type";
        return static_cast<::tflite::LSTMKernelType>(-1);
    }
  }

  // Writes the kernel type of `op`; activation and clip values are fixed.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    ::tflite::LSTMKernelType kernel_type = GetKernelType(op.kernel_type);

    // Current toco converter only supports tanh, no clip.
    return ::tflite::CreateLSTMOptions(*builder, /*fused_activation_function=*/
                                       ::tflite::ActivationFunctionType_TANH,
                                       /*cell_clip=*/0.0,
                                       /*proj_clip=*/0.0, kernel_type);
  }

  // Restores the kernel type onto `op`. A serialized kernel type other than
  // BASIC or FULL is logged and leaves `op->kernel_type` unchanged.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    // Only support tanh activation, so check that tflite type is tanh.
    CHECK(options.fused_activation_function() ==
          ::tflite::ActivationFunctionType_TANH);

    switch (options.kernel_type()) {
      case ::tflite::LSTMKernelType_BASIC:
        op->kernel_type = LstmCellOperator::KERNEL_BASIC;
        break;
      case ::tflite::LSTMKernelType_FULL:
        op->kernel_type = LstmCellOperator::KERNEL_FULL;
        break;
      default:
        // Report the unexpected value instead of ignoring it silently, as
        // GetKernelType does for the opposite direction.
        LOG(ERROR) << "Unhandled Kernel Type";
        break;
    }
  }

  // Computes the version from the tensor types plus the kernel type.
  int GetVersion(const OperatorSignature& op_signature) const override {
    const auto& lstm_op =
        static_cast<const LstmCellOperator&>(*op_signature.op);
    ::tflite::OpSignature op_sig =
        GetVersioningOpSig(builtin_op(), op_signature);
    op_sig.options.lstm.kernel_type = GetKernelType(lstm_op.kernel_type);
    return ::tflite::GetBuiltinOperatorVersion(op_sig);
  }

  // Returns one flag per input of `op`, set to true for the two state inputs
  // of the operator's kernel type. The indices differ between the FULL and
  // BASIC kernels.
  std::vector<bool> GetMutatingInputVariables(
      const Operator& op) const override {
    const auto& lstm_op = static_cast<const LstmCellOperator&>(op);

    std::vector<bool> mutating_input_variables(op.inputs.size(), false);
    switch (lstm_op.kernel_type) {
      case LstmCellOperator::KERNEL_FULL: {
        mutating_input_variables[kInputActivationStateTensor] = true;
        mutating_input_variables[kInputCellStateTensor] = true;
        break;
      }
      case LstmCellOperator::KERNEL_BASIC: {
        mutating_input_variables[LstmCellOperator::PREV_ACTIV_INPUT] = true;
        mutating_input_variables[LstmCellOperator::PREV_STATE_INPUT] = true;
        break;
      }
    }
    return mutating_input_variables;
  }
};
// Converter glue for the UNIDIRECTIONAL_SEQUENCE_LSTM builtin.
class UnidirectionalSequenceLstm
    : public BuiltinOperator<
          UnidirectionalSequenceLstmOperator,
          ::tflite::UnidirectionalSequenceLSTMOptions,
          ::tflite::BuiltinOptions_UnidirectionalSequenceLSTMOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Always writes tanh activation, zero clips and time_major=true; nothing is
  // taken from `op`.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    // Current toco converter only supports tanh, no clip.
    const auto activation = ::tflite::ActivationFunctionType_TANH;
    return ::tflite::CreateUnidirectionalSequenceLSTMOptions(
        *builder, activation,
        /*cell_clip=*/0.0,
        /*proj_clip=*/0.0,
        /*time_major=*/true);
  }

  // Copies nothing into `op`; only checks (DCHECK) that the activation is
  // tanh.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    // Only support tanh activation, so check that tflite type is tanh.
    DCHECK(options.fused_activation_function() ==
           ::tflite::ActivationFunctionType_TANH);
  }

  // Returns one flag per input, with the activation-state and cell-state
  // inputs set to true.
  std::vector<bool> GetMutatingInputVariables(
      const Operator& op) const override {
    std::vector<bool> mutated(op.inputs.size(), false);
    mutated[kInputActivationStateTensor] = true;
    mutated[kInputCellStateTensor] = true;
    return mutated;
  }
};
// Converter glue for the BIDIRECTIONAL_SEQUENCE_LSTM builtin.
class BidirectionalSequenceLstm
    : public BuiltinOperator<
          BidirectionalSequenceLstmOperator,
          ::tflite::BidirectionalSequenceLSTMOptions,
          ::tflite::BuiltinOptions_BidirectionalSequenceLSTMOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Writes tanh activation, zero clips, the op's merge_outputs flag and
  // time_major=true.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    // Current toco converter only supports tanh, no clip.
    const auto activation = ::tflite::ActivationFunctionType_TANH;
    return ::tflite::CreateBidirectionalSequenceLSTMOptions(
        *builder, activation,
        /*cell_clip=*/0.0,
        /*proj_clip=*/0.0,
        /*merge_outputs=*/op.merge_outputs,
        /*time_major=*/true);
  }

  // Restores merge_outputs; the activation is only checked (DCHECK).
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    // Only support tanh activation, so check that tflite type is tanh.
    DCHECK(options.fused_activation_function() ==
           ::tflite::ActivationFunctionType_TANH);
    op->merge_outputs = options.merge_outputs();
  }

  // Returns one flag per input, with the four state inputs set to true.
  std::vector<bool> GetMutatingInputVariables(
      const Operator& op) const override {
    // Input positions of the state tensors.
    constexpr int kForwardActivationState = 35;
    constexpr int kForwardCellState = 36;
    constexpr int kBackwardActivationState = 37;
    constexpr int kBackwardCellState = 38;

    std::vector<bool> mutated(op.inputs.size(), false);
    mutated[kForwardActivationState] = true;
    mutated[kForwardCellState] = true;
    mutated[kBackwardActivationState] = true;
    mutated[kBackwardCellState] = true;
    return mutated;
  }
};
// Converter glue for the BIDIRECTIONAL_SEQUENCE_RNN builtin.
class BidirectionalSequenceRnn
    : public BuiltinOperator<
          BidirectionalSequenceRnnOperator,
          ::tflite::BidirectionalSequenceRNNOptions,
          ::tflite::BuiltinOptions_BidirectionalSequenceRNNOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Writes time_major=true, tanh activation and the op's merge_outputs flag.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    // Current toco converter only supports tanh, no clip.
    const auto activation = ::tflite::ActivationFunctionType_TANH;
    return ::tflite::CreateBidirectionalSequenceRNNOptions(
        *builder, /*time_major=*/true, activation,
        /*merge_outputs=*/op.merge_outputs);
  }

  // Restores merge_outputs; the activation is only checked (DCHECK).
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    // Only support tanh activation, so check that tflite type is tanh.
    DCHECK(options.fused_activation_function() ==
           ::tflite::ActivationFunctionType_TANH);
    op->merge_outputs = options.merge_outputs();
  }

  // Returns one flag per input, with both hidden-state inputs set to true.
  std::vector<bool> GetMutatingInputVariables(
      const Operator& op) const override {
    // Input positions of the hidden-state tensors.
    constexpr int kForwardHiddenState = 4;
    constexpr int kBackwardHiddenState = 8;

    std::vector<bool> mutated(op.inputs.size(), false);
    mutated[kForwardHiddenState] = true;
    mutated[kBackwardHiddenState] = true;
    return mutated;
  }
};
// Converter glue for MeanOperator; the only option carried is keep_dims.
class Mean : public BuiltinOperator<MeanOperator, ::tflite::ReducerOptions,
                                    ::tflite::BuiltinOptions_ReducerOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Serializes op.keep_dims into a ReducerOptions table.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto keep_dims = op.keep_dims;
    return ::tflite::CreateReducerOptions(*builder, keep_dims);
  }

  // Restores keep_dims from the serialized options.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    const auto keep_dims = options.keep_dims();
    op->keep_dims = keep_dims;
  }
};
// Converter glue for TensorFlowSumOperator; the only option carried is
// keep_dims.
class Sum
    : public BuiltinOperator<TensorFlowSumOperator, ::tflite::ReducerOptions,
                             ::tflite::BuiltinOptions_ReducerOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Serializes op.keep_dims into a ReducerOptions table.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto keep_dims = op.keep_dims;
    return ::tflite::CreateReducerOptions(*builder, keep_dims);
  }

  // Restores keep_dims from the serialized options.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    const auto keep_dims = options.keep_dims();
    op->keep_dims = keep_dims;
  }
};
// Converter glue for TensorFlowMaxOperator; the only option carried is
// keep_dims.
class ReduceMax
    : public BuiltinOperator<TensorFlowMaxOperator, ::tflite::ReducerOptions,
                             ::tflite::BuiltinOptions_ReducerOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Serializes op.keep_dims into a ReducerOptions table.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto keep_dims = op.keep_dims;
    return ::tflite::CreateReducerOptions(*builder, keep_dims);
  }

  // Restores keep_dims from the serialized options.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    const auto keep_dims = options.keep_dims();
    op->keep_dims = keep_dims;
  }
};
// Converter glue for TensorFlowMinOperator; the only option carried is
// keep_dims.
class ReduceMin
    : public BuiltinOperator<TensorFlowMinOperator, ::tflite::ReducerOptions,
                             ::tflite::BuiltinOptions_ReducerOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Serializes op.keep_dims into a ReducerOptions table.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto keep_dims = op.keep_dims;
    return ::tflite::CreateReducerOptions(*builder, keep_dims);
  }

  // Restores keep_dims from the serialized options.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    const auto keep_dims = options.keep_dims();
    op->keep_dims = keep_dims;
  }
};
// Converter glue for TensorFlowProdOperator; the only option carried is
// keep_dims.
class ReduceProd
    : public BuiltinOperator<TensorFlowProdOperator, ::tflite::ReducerOptions,
                             ::tflite::BuiltinOptions_ReducerOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Serializes op.keep_dims into a ReducerOptions table.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto keep_dims = op.keep_dims;
    return ::tflite::CreateReducerOptions(*builder, keep_dims);
  }

  // Restores keep_dims from the serialized options.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    const auto keep_dims = options.keep_dims();
    op->keep_dims = keep_dims;
  }
};
// Converter glue for TensorFlowAnyOperator; the only option carried is
// keep_dims.
class ReduceAny
    : public BuiltinOperator<TensorFlowAnyOperator, ::tflite::ReducerOptions,
                             ::tflite::BuiltinOptions_ReducerOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Serializes op.keep_dims into a ReducerOptions table.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto keep_dims = op.keep_dims;
    return ::tflite::CreateReducerOptions(*builder, keep_dims);
  }

  // Restores keep_dims from the serialized options.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    const auto keep_dims = options.keep_dims();
    op->keep_dims = keep_dims;
  }
};
// Converter glue for the RESIZE_BILINEAR builtin. Carries align_corners and
// half_pixel_centers, and feeds both into op versioning.
class ResizeBilinear
    : public BuiltinOperator<ResizeBilinearOperator,
                             ::tflite::ResizeBilinearOptions,
                             ::tflite::BuiltinOptions_ResizeBilinearOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Serializes align_corners and half_pixel_centers.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto align = op.align_corners;
    const auto half_pixel = op.half_pixel_centers;
    return ::tflite::CreateResizeBilinearOptions(*builder, align, half_pixel);
  }

  // Restores align_corners and half_pixel_centers.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    TocoOperator& target = *op;
    target.align_corners = options.align_corners();
    target.half_pixel_centers = options.half_pixel_centers();
  }

  // Computes the builtin op version from the signature plus the two resize
  // flags.
  int GetVersion(const OperatorSignature& op_signature) const override {
    ::tflite::OpSignature versioning_sig =
        GetVersioningOpSig(builtin_op(), op_signature);
    const auto& resize_op =
        static_cast<const ResizeBilinearOperator&>(*op_signature.op);
    versioning_sig.options.resize.half_pixel_centers =
        resize_op.half_pixel_centers;
    versioning_sig.options.resize.align_corners = resize_op.align_corners;
    return ::tflite::GetBuiltinOperatorVersion(versioning_sig);
  }
};
// Converter glue for the RESIZE_NEAREST_NEIGHBOR builtin. Carries
// align_corners and half_pixel_centers, and feeds both into op versioning.
class ResizeNearestNeighbor
    : public BuiltinOperator<
          ResizeNearestNeighborOperator, ::tflite::ResizeNearestNeighborOptions,
          ::tflite::BuiltinOptions_ResizeNearestNeighborOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Serializes align_corners and half_pixel_centers.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto align = op.align_corners;
    const auto half_pixel = op.half_pixel_centers;
    return ::tflite::CreateResizeNearestNeighborOptions(*builder, align,
                                                        half_pixel);
  }

  // Restores align_corners and half_pixel_centers.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    TocoOperator& target = *op;
    target.align_corners = options.align_corners();
    target.half_pixel_centers = options.half_pixel_centers();
  }

  // Computes the builtin op version from the signature plus the two resize
  // flags.
  int GetVersion(const OperatorSignature& op_signature) const override {
    ::tflite::OpSignature versioning_sig =
        GetVersioningOpSig(builtin_op(), op_signature);
    const auto& nearest_op =
        static_cast<const ResizeNearestNeighborOperator&>(*op_signature.op);
    versioning_sig.options.resize.half_pixel_centers =
        nearest_op.half_pixel_centers;
    versioning_sig.options.resize.align_corners = nearest_op.align_corners;
    return ::tflite::GetBuiltinOperatorVersion(versioning_sig);
  }
};
// Converter glue for the SQUEEZE builtin; the only option carried is the list
// of dimensions to squeeze.
class Squeeze
    : public BuiltinOperator<SqueezeOperator, ::tflite::SqueezeOptions,
                             ::tflite::BuiltinOptions_SqueezeOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Serializes op.squeeze_dims as a flatbuffer vector inside SqueezeOptions.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    auto squeeze_dims = builder->CreateVector(op.squeeze_dims);
    return ::tflite::CreateSqueezeOptions(*builder, squeeze_dims);
  }

  // Appends the serialized squeeze dimensions to op->squeeze_dims.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    // FlatBuffers vector accessors yield nullptr when the field was not
    // serialized; treat that as "no dimensions" rather than dereferencing it.
    const auto* serialized_dims = options.squeeze_dims();
    if (serialized_dims == nullptr) {
      return;
    }
    op->squeeze_dims.insert(op->squeeze_dims.end(), serialized_dims->begin(),
                            serialized_dims->end());
  }
};
// Converter glue for TensorFlowSplitOperator; the only option carried is the
// number of splits.
class Split
    : public BuiltinOperator<TensorFlowSplitOperator, ::tflite::SplitOptions,
                             ::tflite::BuiltinOptions_SplitOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Serializes op.num_split.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto split_count = op.num_split;
    return ::tflite::CreateSplitOptions(*builder, split_count);
  }

  // Restores num_split from the serialized num_splits field.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    const auto split_count = options.num_splits();
    op->num_split = split_count;
  }
};
// Converter glue for TensorFlowSplitVOperator; the only option carried is the
// number of splits.
class SplitV
    : public BuiltinOperator<TensorFlowSplitVOperator, ::tflite::SplitVOptions,
                             ::tflite::BuiltinOptions_SplitVOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Serializes op.num_split.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto split_count = op.num_split;
    return ::tflite::CreateSplitVOptions(*builder, split_count);
  }

  // Restores num_split from the serialized num_splits field.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    const auto split_count = options.num_splits();
    op->num_split = split_count;
  }
};
// Converter glue for the STRIDED_SLICE builtin. Carries the five slice masks
// and feeds the number of start indices into op versioning.
class StridedSlice
    : public BuiltinOperator<StridedSliceOperator,
                             ::tflite::StridedSliceOptions,
                             ::tflite::BuiltinOptions_StridedSliceOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Serializes the begin/end/ellipsis/new-axis/shrink-axis masks, in that
  // order.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    flatbuffers::FlatBufferBuilder& fbb = *builder;
    return ::tflite::CreateStridedSliceOptions(fbb, op.begin_mask, op.end_mask,
                                               op.ellipsis_mask,
                                               op.new_axis_mask,
                                               op.shrink_axis_mask);
  }

  // Restores all five masks from the serialized options.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    TocoOperator& target = *op;
    target.begin_mask = options.begin_mask();
    target.end_mask = options.end_mask();
    target.ellipsis_mask = options.ellipsis_mask();
    target.new_axis_mask = options.new_axis_mask();
    target.shrink_axis_mask = options.shrink_axis_mask();
  }

  // Computes the builtin op version; num_dims is taken from the size of the
  // op's start_indices.
  int GetVersion(const OperatorSignature& op_signature) const override {
    ::tflite::OpSignature versioning_sig =
        GetVersioningOpSig(builtin_op(), op_signature);
    const auto& slice_op =
        static_cast<const StridedSliceOperator&>(*op_signature.op);
    versioning_sig.options.single_input_op.num_dims =
        slice_op.start_indices.size();
    return ::tflite::GetBuiltinOperatorVersion(versioning_sig);
  }
};
// Converter glue for TopKV2Operator. No option values are written explicitly,
// and nothing is read back.
class TopK_V2 : public BuiltinOperator<TopKV2Operator, ::tflite::TopKV2Options,
                                       ::tflite::BuiltinOptions_TopKV2Options> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Builds a TopKV2Options table, passing only the builder.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    flatbuffers::FlatBufferBuilder& fbb = *builder;
    return ::tflite::CreateTopKV2Options(fbb);
  }

  // Deliberately empty: no field of `options` is copied into `op`.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {}
};
// Converter glue for ArgMaxOperator; the only option carried is the output
// data type.
class ArgMax : public BuiltinOperator<ArgMaxOperator, ::tflite::ArgMaxOptions,
                                      ::tflite::BuiltinOptions_ArgMaxOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Serializes op.output_data_type via DataType::Serialize.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto serialized_type = DataType::Serialize(op.output_data_type);
    return ::tflite::CreateArgMaxOptions(*builder, serialized_type);
  }

  // Restores output_data_type via DataType::Deserialize.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    const auto serialized_type = options.output_type();
    op->output_data_type = DataType::Deserialize(serialized_type);
  }
};
// Converter glue for ArgMinOperator; the only option carried is the output
// data type.
class ArgMin : public BuiltinOperator<ArgMinOperator, ::tflite::ArgMinOptions,
                                      ::tflite::BuiltinOptions_ArgMinOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Serializes op.output_data_type via DataType::Serialize.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto serialized_type = DataType::Serialize(op.output_data_type);
    return ::tflite::CreateArgMinOptions(*builder, serialized_type);
  }

  // Restores output_data_type via DataType::Deserialize.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    const auto serialized_type = options.output_type();
    op->output_data_type = DataType::Deserialize(serialized_type);
  }
};
// Converter glue for the TRANSPOSE_CONV builtin. Carries the padding type and
// the two strides.
class TransposeConv
    : public BuiltinOperator<TransposeConvOperator,
                             ::tflite::TransposeConvOptions,
                             ::tflite::BuiltinOptions_TransposeConvOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Serializes padding, then stride_width and stride_height (in that order).
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    return ::tflite::CreateTransposeConvOptions(
        *builder, Padding::Serialize(op.padding.type), op.stride_width,
        op.stride_height);
  }

  // Restores padding type and both strides from the serialized options.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    TocoOperator& target = *op;
    target.padding.type = Padding::Deserialize(options.padding());
    target.stride_width = options.stride_w();
    target.stride_height = options.stride_h();
  }
};
// Converter glue for the SPARSE_TO_DENSE builtin; the only option carried is
// validate_indices.
class SparseToDense
    : public BuiltinOperator<SparseToDenseOperator,
                             ::tflite::SparseToDenseOptions,
                             ::tflite::BuiltinOptions_SparseToDenseOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Serializes op.validate_indices.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto validate = op.validate_indices;
    return ::tflite::CreateSparseToDenseOptions(*builder, validate);
  }

  // Restores validate_indices from the serialized options.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    const auto validate = options.validate_indices();
    op->validate_indices = validate;
  }
};
// Converter glue for ExpandDimsOperator. No option values are written
// explicitly, and nothing is read back.
class ExpandDims
    : public BuiltinOperator<ExpandDimsOperator, ::tflite::ExpandDimsOptions,
                             ::tflite::BuiltinOptions_ExpandDimsOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Builds an ExpandDimsOptions table, passing only the builder.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    flatbuffers::FlatBufferBuilder& fbb = *builder;
    return ::tflite::CreateExpandDimsOptions(fbb);
  }

  // Deliberately empty: no field of `options` is copied into `op`.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {}
};
// Converter glue for PackOperator. Carries values_count and axis.
class Pack : public BuiltinOperator<PackOperator, ::tflite::PackOptions,
                                    ::tflite::BuiltinOptions_PackOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Serializes values_count followed by axis.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto count = op.values_count;
    const auto pack_axis = op.axis;
    return ::tflite::CreatePackOptions(*builder, count, pack_axis);
  }

  // Restores values_count and axis from the serialized options.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    TocoOperator& target = *op;
    target.values_count = options.values_count();
    target.axis = options.axis();
  }
};
// Converter glue for TensorFlowShapeOperator; the only option carried is the
// output data type.
class Shape
    : public BuiltinOperator<TensorFlowShapeOperator, ::tflite::ShapeOptions,
                             ::tflite::BuiltinOptions_ShapeOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Serializes op.output_data_type via DataType::Serialize.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto serialized_type = DataType::Serialize(op.output_data_type);
    return ::tflite::CreateShapeOptions(*builder, serialized_type);
  }

  // Restores output_data_type from the serialized out_type field.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    const auto serialized_type = options.out_type();
    op->output_data_type = DataType::Deserialize(serialized_type);
  }
};
// Converter glue for OneHotOperator; the only option carried is axis.
class OneHot : public BuiltinOperator<OneHotOperator, ::tflite::OneHotOptions,
                                      ::tflite::BuiltinOptions_OneHotOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Serializes op.axis.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto one_hot_axis = op.axis;
    return ::tflite::CreateOneHotOptions(*builder, one_hot_axis);
  }

  // Restores axis from the serialized options.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    const auto one_hot_axis = options.axis();
    op->axis = one_hot_axis;
  }
};
// Custom-op glue for CTCBeamSearchDecoderOperator. Options travel as a
// flexbuffer with the keys "beam_width", "top_paths" and "merge_repeated".
class CTCBeamSearchDecoder
    : public CustomOperator<CTCBeamSearchDecoderOperator> {
 public:
  using CustomOperator::CustomOperator;

  // Writes the three decoder settings into `fbb` under their keys.
  void WriteOptions(const TocoOperator& op,
                    flexbuffers::Builder* fbb) const override {
    flexbuffers::Builder& out = *fbb;
    out.Int("beam_width", op.beam_width);
    out.Int("top_paths", op.top_paths);
    out.Bool("merge_repeated", op.merge_repeated);
  }

  // Reads the three decoder settings back from the flexbuffer map.
  void ReadOptions(const flexbuffers::Map& m, TocoOperator* op) const override {
    TocoOperator& target = *op;
    target.beam_width = m["beam_width"].AsInt32();
    target.top_paths = m["top_paths"].AsInt32();
    target.merge_repeated = m["merge_repeated"].AsBool();
  }

  // Always reports version 1, regardless of the signature.
  int GetVersion(const OperatorSignature& op_signature) const override {
    return 1;
  }
};
// Converter glue for UnpackOperator. Carries num and axis; the op version is
// chosen from the data type of the first input.
class Unpack : public BuiltinOperator<UnpackOperator, ::tflite::UnpackOptions,
                                      ::tflite::BuiltinOptions_UnpackOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Serializes num followed by axis.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto output_count = op.num;
    const auto unpack_axis = op.axis;
    return ::tflite::CreateUnpackOptions(*builder, output_count, unpack_axis);
  }

  // Restores num and axis from the serialized options.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    TocoOperator& target = *op;
    target.num = options.num();
    target.axis = options.axis();
  }

  // Returns 2 for int8/uint8 input, 3 for bool input, and 1 otherwise.
  int GetVersion(const OperatorSignature& op_signature) const override {
    const Array& first_input =
        op_signature.model->GetArray(op_signature.op->inputs[0]);
    switch (first_input.data_type) {
      case ArrayDataType::kInt8:
      case ArrayDataType::kUint8:
        // If the op take int8/uint8 input, it is version 2.
        return 2;
      case ArrayDataType::kBool:
        // If the op take bool input, it is version 3.
        return 3;
      default:
        return 1;
    }
  }
};
// Converter glue for LeakyReluOperator; the only option carried is alpha.
class LeakyRelu
    : public BuiltinOperator<LeakyReluOperator, ::tflite::LeakyReluOptions,
                             ::tflite::BuiltinOptions_LeakyReluOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Serializes op.alpha.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto slope = op.alpha;
    return ::tflite::CreateLeakyReluOptions(*builder, slope);
  }

  // Restores alpha from the serialized options.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    const auto slope = options.alpha();
    op->alpha = slope;
  }
};
// Converter glue for SquaredDifferenceOperator. No option values are written
// explicitly, and nothing is read back.
class SquaredDifference
    : public BuiltinOperator<
          SquaredDifferenceOperator, ::tflite::SquaredDifferenceOptions,
          ::tflite::BuiltinOptions_SquaredDifferenceOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Builds a SquaredDifferenceOptions table, passing only the builder.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    flatbuffers::FlatBufferBuilder& fbb = *builder;
    return ::tflite::CreateSquaredDifferenceOptions(fbb);
  }

  // Deliberately empty: no field of `options` is copied into `op`.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {}
};
// Converter glue for MirrorPadOperator; the only option carried is the
// padding mode (reflect vs. symmetric).
class MirrorPad
    : public BuiltinOperator<MirrorPadOperator, ::tflite::MirrorPadOptions,
                             ::tflite::BuiltinOptions_MirrorPadOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Serializes the mode: kReflect maps to REFLECT, anything else to SYMMETRIC.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    ::tflite::MirrorPadMode schema_mode =
        ::tflite::MirrorPadMode::MirrorPadMode_SYMMETRIC;
    if (op.mode == MirrorPadMode::kReflect) {
      schema_mode = ::tflite::MirrorPadMode::MirrorPadMode_REFLECT;
    }
    return ::tflite::CreateMirrorPadOptions(*builder, schema_mode);
  }

  // Restores the mode: REFLECT maps to kReflect, anything else to kSymmetric.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    if (options.mode() == ::tflite::MirrorPadMode::MirrorPadMode_REFLECT) {
      op->mode = MirrorPadMode::kReflect;
    } else {
      op->mode = MirrorPadMode::kSymmetric;
    }
  }
};
// Converter glue for UniqueOperator; the only option carried is the index
// output type (int64 vs. int32).
class Unique : public BuiltinOperator<UniqueOperator, ::tflite::UniqueOptions,
                                      ::tflite::BuiltinOptions_UniqueOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Serializes the index type: kInt64 maps to INT64, anything else to INT32.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const UniqueOperator& unique_op = static_cast<const UniqueOperator&>(op);
    ::tflite::TensorType index_type = ::tflite::TensorType_INT32;
    if (unique_op.idx_out_type == toco::ArrayDataType::kInt64) {
      index_type = ::tflite::TensorType::TensorType_INT64;
    }
    return ::tflite::CreateUniqueOptions(*builder, index_type);
  }

  // Restores the index type: INT64 maps to kInt64, anything else to kInt32.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    UniqueOperator* unique_op = static_cast<UniqueOperator*>(op);
    if (options.idx_out_type() == ::tflite::TensorType_INT64) {
      unique_op->idx_out_type = toco::ArrayDataType::kInt64;
    } else {
      unique_op->idx_out_type = toco::ArrayDataType::kInt32;
    }
  }
};
// Converter glue for the UNIDIRECTIONAL_SEQUENCE_RNN builtin.
class UnidirectionalSequenceRnn
    : public BuiltinOperator<UnidirectionalSequenceRnnOperator,
                             ::tflite::SequenceRNNOptions,
                             ::tflite::BuiltinOptions_SequenceRNNOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Always writes time_major=true and tanh activation; nothing is taken from
  // `op`.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto activation = ::tflite::ActivationFunctionType_TANH;
    return ::tflite::CreateSequenceRNNOptions(*builder, /*time_major=*/true,
                                              activation);
  }

  // Copies nothing into `op`; only checks (DCHECK) that the activation is
  // tanh.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    // Only support tanh activation, so check that tflite type is tanh.
    DCHECK(options.fused_activation_function() ==
           ::tflite::ActivationFunctionType_TANH);
  }

  // Returns one flag per input, with the input at index 4 set to true.
  std::vector<bool> GetMutatingInputVariables(
      const Operator& op) const override {
    constexpr int kMutatedInputIndex = 4;
    std::vector<bool> mutated(op.inputs.size(), false);
    mutated[kMutatedInputIndex] = true;
    return mutated;
  }
};
// Converter glue for WhereOperator. No option values are written explicitly,
// and nothing is read back.
class Where : public BuiltinOperator<WhereOperator, ::tflite::WhereOptions,
                                     ::tflite::BuiltinOptions_WhereOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Builds a WhereOptions table, passing only the builder.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    flatbuffers::FlatBufferBuilder& fbb = *builder;
    return ::tflite::CreateWhereOptions(fbb);
  }

  // Deliberately empty: no field of `options` is copied into `op`.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {}
};
std::unique_ptr<flexbuffers::Builder> WriteFlexOpOptions(
const std::string& tensorflow_node_def) {
auto fbb = absl::make_unique<flexbuffers::Builder>();
::tensorflow::NodeDef node_def;
if (!node_def.ParseFromString(tensorflow_node_def)) {
LOG(ERROR) << "Failed to parse TensorFlow NodeDef";
return {};
}
fbb->Vector([&]() {
fbb->String(node_def.op());
fbb->String(tensorflow_node_def);
});
fbb->Finish();
LOG(INFO) << "Writing flex op: " << node_def.op();
return std::unique_ptr<flexbuffers::Builder>(fbb.release());
}
// Serializer for TensorFlowUnsupportedOperator. The op's NodeDef is exported
// either as flex-op custom options (op name + raw NodeDef) or as a flexbuffer
// map of its scalar and list attributes.
class TensorFlowUnsupported : public BaseOperator {
 public:
  // `enable_select_tf_ops` selects the flex-op export path in WriteOptions.
  TensorFlowUnsupported(const std::string& name, OperatorType type,
                        bool enable_select_tf_ops)
      : BaseOperator(name, type), enable_select_tf_ops_(enable_select_tf_ops) {}

  // Serializes `op` into custom options. When WriteOptions produces no
  // builder, a zero offset is passed to Options::Custom instead.
  Options Serialize(const Operator& op,
                    flatbuffers::FlatBufferBuilder* builder) const override {
    auto fbb =
        WriteOptions(static_cast<const TensorFlowUnsupportedOperator&>(op));
    if (fbb) {
      return Options::Custom(builder->CreateVector(fbb->GetBuffer()));
    } else {
      return Options::Custom(0);
    }
  }

  // Builds a TensorFlowUnsupportedOperator, populating it from
  // `custom_options` (interpreted as a flexbuffer map) when present.
  // `builtin_options` is not consulted.
  std::unique_ptr<Operator> Deserialize(
      const BuiltinOptions* builtin_options,
      const CustomOptions* custom_options) const override {
    // Deserializing Flex ops doesn't work now.
    // TODO(ycling): Revisit and decide if we should fix the flow for importing
    // TFLite models with Flex ops.
    auto op = absl::make_unique<TensorFlowUnsupportedOperator>();
    if (custom_options) {
      auto flexbuffer_map =
          flexbuffers::GetRoot(custom_options->data(), custom_options->size())
              .AsMap();
      ReadOptions(flexbuffer_map, op.get());
    }
    return std::unique_ptr<Operator>(op.release());
  }

  // Encodes op.tensorflow_node_def as a flexbuffer. Returns a null pointer if
  // the NodeDef cannot be parsed or if none of its attributes has a supported
  // type.
  std::unique_ptr<flexbuffers::Builder> WriteOptions(
      const TensorFlowUnsupportedOperator& op) const {
    if (enable_select_tf_ops_) {
      return WriteFlexOpOptions(op.tensorflow_node_def);
    }
    auto fbb = absl::make_unique<flexbuffers::Builder>();

    ::tensorflow::NodeDef node_def;
    if (!node_def.ParseFromString(op.tensorflow_node_def)) {
      LOG(ERROR) << "Failed to parse TensorFlow NodeDef";
      return std::unique_ptr<flexbuffers::Builder>();
    }

    if (ShouldExportAsFlexOp(enable_select_tf_ops_, node_def.op())) {
      fbb->Vector([&]() {
        fbb->String(node_def.op());
        fbb->String(op.tensorflow_node_def);
      });
      fbb->Finish();
      LOG(INFO) << "Writing flex op: " << node_def.op();
      return std::unique_ptr<flexbuffers::Builder>(fbb.release());
    }

    // Otherwise export the attributes as a key/value map. Only string, int,
    // float, bool and homogeneous string/int/float lists are written; every
    // other attribute is skipped with a warning.
    bool has_valid_attr = false;
    size_t map_start = fbb->StartMap();
    for (const auto& pair : node_def.attr()) {
      const char* key = pair.first.c_str();
      const auto& attr = pair.second;
      switch (attr.value_case()) {
        case ::tensorflow::AttrValue::kS:
          fbb->String(key, attr.s());
          has_valid_attr = true;
          break;
        case ::tensorflow::AttrValue::kI:
          fbb->Int(key, attr.i());
          has_valid_attr = true;
          break;
        case ::tensorflow::AttrValue::kF:
          fbb->Float(key, attr.f());
          has_valid_attr = true;
          break;
        case ::tensorflow::AttrValue::kB:
          fbb->Bool(key, attr.b());
          has_valid_attr = true;
          break;
        case tensorflow::AttrValue::kList:
          // The first non-empty list (strings, then ints, then floats) decides
          // the element type of the vector that is written.
          if (attr.list().s_size() > 0) {
            auto start = fbb->StartVector(key);
            for (const std::string& v : attr.list().s()) {
              fbb->Add(v);
            }
            fbb->EndVector(start, /*typed=*/true, /*fixed=*/false);
            has_valid_attr = true;
          } else if (attr.list().i_size() > 0) {
            auto start = fbb->StartVector(key);
            for (const int64_t v : attr.list().i()) {
              fbb->Add(v);
            }
            fbb->EndVector(start, /*typed=*/true, /*fixed=*/false);
            has_valid_attr = true;
          } else if (attr.list().f_size() > 0) {
            auto start = fbb->StartVector(key);
            for (const float v : attr.list().f()) {
              fbb->Add(v);
            }
            fbb->EndVector(start, /*typed=*/true, /*fixed=*/false);
            has_valid_attr = true;
          } else {
            LOG(WARNING)
                << "Ignoring unsupported type in list attribute with key '"
                << key << "'";
          }
          break;
        default:
          LOG(WARNING) << "Ignoring unsupported attribute type with key '"
                       << key << "'";
          break;
      }
    }
    if (!has_valid_attr) {
      return std::unique_ptr<flexbuffers::Builder>();
    }
    fbb->EndMap(map_start);
    fbb->Finish();
    return std::unique_ptr<flexbuffers::Builder>(fbb.release());
  }

  // Rebuilds a NodeDef from the flexbuffer attribute map `m` and stores its
  // serialized form in op->tensorflow_node_def. The boolean attributes
  // "_output_quantized" and "_support_output_type_float_in_quantized_op" are
  // additionally mirrored onto the matching fields of `op`.
  void ReadOptions(const flexbuffers::Map& m,
                   TensorFlowUnsupportedOperator* op) const {
    ::tensorflow::NodeDef node_def;
    auto attr = node_def.mutable_attr();

    const auto& keys = m.Keys();
    for (size_t i = 0; i < keys.size(); ++i) {
      const auto key = keys[i].AsKey();
      const auto& value = m[key];
      switch (value.GetType()) {
        case flexbuffers::FBT_STRING:
          // str() copies the whole length-prefixed payload. c_str() would cut
          // the value at the first embedded NUL even though WriteOptions
          // stores the complete std::string.
          (*attr)[key].set_s(value.AsString().str());
          break;
        case flexbuffers::FBT_INT:
          (*attr)[key].set_i(value.AsInt64());
          break;
        case flexbuffers::FBT_FLOAT:
          (*attr)[key].set_f(value.AsFloat());
          break;
        case flexbuffers::FBT_BOOL:
          (*attr)[key].set_b(value.AsBool());
          if (std::string(key) == "_output_quantized") {
            op->quantized = value.AsBool();
          }
          if (std::string(key) ==
              "_support_output_type_float_in_quantized_op") {
            op->support_output_type_float_in_quantized_op = value.AsBool();
          }
          break;
        case flexbuffers::FBT_VECTOR_INT: {
          auto* list = (*attr)[key].mutable_list();
          const auto& vector = value.AsTypedVector();
          for (size_t j = 0; j < vector.size(); j++) {
            list->add_i(vector[j].AsInt64());
          }
          break;
        }
        case flexbuffers::FBT_VECTOR_FLOAT: {
          auto* list = (*attr)[key].mutable_list();
          const auto& vector = value.AsTypedVector();
          for (size_t j = 0; j < vector.size(); j++) {
            list->add_f(vector[j].AsFloat());
          }
          break;
        }
        case 15 /* TO_DO(wvo): flexbuffers::FBT_VECTOR_STRING_DEPRECATED*/: {
          auto* list = (*attr)[key].mutable_list();
          const auto& vector = value.AsTypedVector();
          for (size_t j = 0; j < vector.size(); j++) {
            list->add_s(vector[j].AsString().str());
          }
          break;
        }
        default:
          LOG(WARNING) << "Ignoring unsupported attribute type with key '"
                       << key << "'";
          break;
      }
    }
    node_def.SerializeToString(&op->tensorflow_node_def);
  }

  // Always reports version 1.
  int GetVersion(const OperatorSignature& op_signature) const override {
    // TODO(ycling): Design and implement a way to plumb the version of
    // custom ops.
    return 1;
  }

 private:
  // When true, WriteOptions delegates to WriteFlexOpOptions.
  const bool enable_select_tf_ops_;
};
// Converter glue for DequantizeOperator. No option values are written
// explicitly, and nothing is read back.
class Dequantize
    : public BuiltinOperator<DequantizeOperator, ::tflite::DequantizeOptions,
                             ::tflite::BuiltinOptions_DequantizeOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Builds a DequantizeOptions table, passing only the builder.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    flatbuffers::FlatBufferBuilder& fbb = *builder;
    return ::tflite::CreateDequantizeOptions(fbb);
  }

  // Deliberately empty: no field of `options` is copied into `op`.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {}
};
// Converter glue for ReverseSequenceOperator. Carries seq_dim and batch_dim.
class ReverseSequence
    : public BuiltinOperator<ReverseSequenceOperator,
                             ::tflite::ReverseSequenceOptions,
                             ::tflite::BuiltinOptions_ReverseSequenceOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Serializes seq_dim followed by batch_dim.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto sequence_dim = op.seq_dim;
    const auto batch_dimension = op.batch_dim;
    return ::tflite::CreateReverseSequenceOptions(*builder, sequence_dim,
                                                  batch_dimension);
  }

  // Restores seq_dim and batch_dim from the serialized options.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    TocoOperator& target = *op;
    target.seq_dim = options.seq_dim();
    target.batch_dim = options.batch_dim();
  }
};
namespace {

// Build a vector containing all the known operators.
//
// `enable_select_tf_ops` is forwarded only to the TensorFlowUnsupported
// converter; every other converter is constructed the same way regardless of
// its value.
//
// Every entry is created through MakeUnique so that each operator is owned by
// a unique_ptr from the moment it exists. Handing a raw `new` expression to
// the vector would leak the object if growing the vector threw.
std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList(
    bool enable_select_tf_ops = false) {
  std::vector<std::unique_ptr<BaseOperator>> ops;
  using tensorflow::MakeUnique;

  // Builtin Operators.
  ops.push_back(
      MakeUnique<Add>(::tflite::BuiltinOperator_ADD, OperatorType::kAdd));
  ops.push_back(
      MakeUnique<AddN>(::tflite::BuiltinOperator_ADD_N, OperatorType::kAddN));
  ops.push_back(
      MakeUnique<Div>(::tflite::BuiltinOperator_DIV, OperatorType::kDiv));
  ops.push_back(
      MakeUnique<Sub>(::tflite::BuiltinOperator_SUB, OperatorType::kSub));
  ops.push_back(MakeUnique<AveragePool>(
      ::tflite::BuiltinOperator_AVERAGE_POOL_2D, OperatorType::kAveragePool));
  ops.push_back(
      MakeUnique<SpaceToBatchND>(::tflite::BuiltinOperator_SPACE_TO_BATCH_ND,
                                 OperatorType::kSpaceToBatchND));
  ops.push_back(
      MakeUnique<BatchToSpaceND>(::tflite::BuiltinOperator_BATCH_TO_SPACE_ND,
                                 OperatorType::kBatchToSpaceND));
  ops.push_back(MakeUnique<Concatenation>(
      ::tflite::BuiltinOperator_CONCATENATION, OperatorType::kConcatenation));
  ops.push_back(MakeUnique<Convolution>(::tflite::BuiltinOperator_CONV_2D,
                                        OperatorType::kConv));
  ops.push_back(MakeUnique<DepthwiseConvolution>(
      ::tflite::BuiltinOperator_DEPTHWISE_CONV_2D,
      OperatorType::kDepthwiseConv));
  ops.push_back(MakeUnique<Dequantize>(::tflite::BuiltinOperator_DEQUANTIZE,
                                       OperatorType::kDequantize));
  ops.push_back(
      MakeUnique<FullyConnected>(::tflite::BuiltinOperator_FULLY_CONNECTED,
                                 OperatorType::kFullyConnected));
  ops.push_back(MakeUnique<Gather>(::tflite::BuiltinOperator_GATHER,
                                   OperatorType::kGather));
  ops.push_back(MakeUnique<GatherNd>(::tflite::BuiltinOperator_GATHER_ND,
                                     OperatorType::kGatherNd));
  ops.push_back(
      MakeUnique<L2Normalization>(::tflite::BuiltinOperator_L2_NORMALIZATION,
                                  OperatorType::kL2Normalization));
  ops.push_back(MakeUnique<L2Pool>(::tflite::BuiltinOperator_L2_POOL_2D,
                                   OperatorType::kL2Pool));
  ops.push_back(MakeUnique<LocalResponseNormalization>(
      ::tflite::BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION,
      OperatorType::kLocalResponseNormalization));
  ops.push_back(MakeUnique<MaxPool>(::tflite::BuiltinOperator_MAX_POOL_2D,
                                    OperatorType::kMaxPool));
  ops.push_back(
      MakeUnique<Mul>(::tflite::BuiltinOperator_MUL, OperatorType::kMul));
  ops.push_back(
      MakeUnique<Pad>(::tflite::BuiltinOperator_PAD, OperatorType::kPad));
  ops.push_back(
      MakeUnique<PadV2>(::tflite::BuiltinOperator_PADV2, OperatorType::kPadV2));
  ops.push_back(MakeUnique<Reshape>(::tflite::BuiltinOperator_RESHAPE,
                                    OperatorType::kReshape));
  ops.push_back(MakeUnique<Softmax>(::tflite::BuiltinOperator_SOFTMAX,
                                    OperatorType::kSoftmax));
  ops.push_back(MakeUnique<SpaceToDepth>(
      ::tflite::BuiltinOperator_SPACE_TO_DEPTH, OperatorType::kSpaceToDepth));
  ops.push_back(MakeUnique<DepthToSpace>(
      ::tflite::BuiltinOperator_DEPTH_TO_SPACE, OperatorType::kDepthToSpace));
  ops.push_back(
      MakeUnique<Svdf>(::tflite::BuiltinOperator_SVDF, OperatorType::kSvdf));
  ops.push_back(MakeUnique<Transpose>(::tflite::BuiltinOperator_TRANSPOSE,
                                      OperatorType::kTranspose));
  ops.push_back(
      MakeUnique<Mean>(::tflite::BuiltinOperator_MEAN, OperatorType::kMean));
  ops.push_back(
      MakeUnique<Sum>(::tflite::BuiltinOperator_SUM, OperatorType::kSum));
  ops.push_back(MakeUnique<ReduceProd>(::tflite::BuiltinOperator_REDUCE_PROD,
                                       OperatorType::kReduceProd));
  ops.push_back(MakeUnique<ReduceMax>(::tflite::BuiltinOperator_REDUCE_MAX,
                                      OperatorType::kReduceMax));
  ops.push_back(MakeUnique<ReduceMin>(::tflite::BuiltinOperator_REDUCE_MIN,
                                      OperatorType::kReduceMin));
  ops.push_back(MakeUnique<ReduceAny>(::tflite::BuiltinOperator_REDUCE_ANY,
                                      OperatorType::kAny));
  ops.push_back(
      MakeUnique<ResizeBilinear>(::tflite::BuiltinOperator_RESIZE_BILINEAR,
                                 OperatorType::kResizeBilinear));
  ops.push_back(MakeUnique<ResizeNearestNeighbor>(
      ::tflite::BuiltinOperator_RESIZE_NEAREST_NEIGHBOR,
      OperatorType::kResizeNearestNeighbor));
  ops.push_back(MakeUnique<Squeeze>(::tflite::BuiltinOperator_SQUEEZE,
                                    OperatorType::kSqueeze));
  ops.push_back(
      MakeUnique<Split>(::tflite::BuiltinOperator_SPLIT, OperatorType::kSplit));
  ops.push_back(MakeUnique<SplitV>(::tflite::BuiltinOperator_SPLIT_V,
                                   OperatorType::kSplitV));
  ops.push_back(MakeUnique<StridedSlice>(
      ::tflite::BuiltinOperator_STRIDED_SLICE, OperatorType::kStridedSlice));
  ops.push_back(MakeUnique<TopK_V2>(::tflite::BuiltinOperator_TOPK_V2,
                                    OperatorType::kTopK_V2));
  ops.push_back(MakeUnique<Lstm>(::tflite::BuiltinOperator_LSTM,
                                 OperatorType::kLstmCell));
  ops.push_back(
      MakeUnique<Cast>(::tflite::BuiltinOperator_CAST, OperatorType::kCast));
  ops.push_back(MakeUnique<ArgMax>(::tflite::BuiltinOperator_ARG_MAX,
                                   OperatorType::kArgMax));
  ops.push_back(MakeUnique<ArgMin>(::tflite::BuiltinOperator_ARG_MIN,
                                   OperatorType::kArgMin));
  ops.push_back(
      MakeUnique<Tile>(::tflite::BuiltinOperator_TILE, OperatorType::kTile));
  ops.push_back(MakeUnique<ExpandDims>(::tflite::BuiltinOperator_EXPAND_DIMS,
                                       OperatorType::kExpandDims));
  ops.push_back(MakeUnique<TransposeConv>(
      ::tflite::BuiltinOperator_TRANSPOSE_CONV, OperatorType::kTransposeConv));
  ops.push_back(MakeUnique<SparseToDense>(
      ::tflite::BuiltinOperator_SPARSE_TO_DENSE, OperatorType::kSparseToDense));
  ops.push_back(
      MakeUnique<Shape>(::tflite::BuiltinOperator_SHAPE, OperatorType::kShape));
  ops.push_back(MakeUnique<FakeQuant>(::tflite::BuiltinOperator_FAKE_QUANT,
                                      OperatorType::kFakeQuant));
  ops.push_back(
      MakeUnique<Pack>(::tflite::BuiltinOperator_PACK, OperatorType::kPack));
  ops.push_back(MakeUnique<UnidirectionalSequenceLstm>(
      ::tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM,
      OperatorType::kUnidirectionalSequenceLstm));
  ops.push_back(MakeUnique<BidirectionalSequenceLstm>(
      ::tflite::BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM,
      OperatorType::kBidirectionalSequenceLstm));
  ops.push_back(MakeUnique<BidirectionalSequenceRnn>(
      ::tflite::BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN,
      OperatorType::kBidirectionalSequenceRnn));
  ops.push_back(MakeUnique<OneHot>(::tflite::BuiltinOperator_ONE_HOT,
                                   OperatorType::kOneHot));
  ops.push_back(MakeUnique<Unpack>(::tflite::BuiltinOperator_UNPACK,
                                   OperatorType::kUnpack));
  ops.push_back(MakeUnique<LeakyRelu>(::tflite::BuiltinOperator_LEAKY_RELU,
                                      OperatorType::kLeakyRelu));
  ops.push_back(MakeUnique<SquaredDifference>(
      ::tflite::BuiltinOperator_SQUARED_DIFFERENCE,
      OperatorType::kSquaredDifference));
  ops.push_back(MakeUnique<MirrorPad>(::tflite::BuiltinOperator_MIRROR_PAD,
                                      OperatorType::kMirrorPad));
  ops.push_back(MakeUnique<Unique>(::tflite::BuiltinOperator_UNIQUE,
                                   OperatorType::kUnique));
  ops.push_back(MakeUnique<UnidirectionalSequenceRnn>(
      ::tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN,
      OperatorType::kUnidirectionalSequenceRnn));
  ops.push_back(
      MakeUnique<Where>(::tflite::BuiltinOperator_WHERE, OperatorType::kWhere));
  ops.push_back(
      MakeUnique<ReverseSequence>(::tflite::BuiltinOperator_REVERSE_SEQUENCE,
                                  OperatorType::kReverseSequence));
  ops.push_back(MakeUnique<SimpleOperator<MatrixDiagOperator>>(
      ::tflite::BuiltinOperator_MATRIX_DIAG, OperatorType::kMatrixDiag));
  ops.push_back(MakeUnique<SimpleOperator<MatrixSetDiagOperator>>(
      ::tflite::BuiltinOperator_MATRIX_SET_DIAG, OperatorType::kMatrixSetDiag));

  // Custom Operators.
  ops.push_back(MakeUnique<CTCBeamSearchDecoder>(
      "CTC_BEAM_SEARCH_DECODER", OperatorType::kCTCBeamSearchDecoder));
  ops.push_back(MakeUnique<TensorFlowUnsupported>("TENSORFLOW_UNSUPPORTED",
                                                  OperatorType::kUnsupported,
                                                  enable_select_tf_ops));

  // SimpleOperator was designed to export CUSTOM TF Lite ops, but has since
  // been modified to also export builtins. As TOCO evolved we added warnings
  // when custom ops are exported but SimpleOperator bypasses those. To
  // prevent user confusion we are settling on using SimpleOperator only for
  // builtins.
  ops.push_back(MakeUnique<SimpleOperator<FloorOperator>>(
      ::tflite::BuiltinOperator_FLOOR, OperatorType::kFloor));
  ops.push_back(MakeUnique<SimpleOperator<CeilOperator>>(
      ::tflite::BuiltinOperator_CEIL, OperatorType::kCeil));
  ops.push_back(MakeUnique<SimpleOperator<EluOperator>>(
      ::tflite::BuiltinOperator_ELU, OperatorType::kElu));
  ops.push_back(MakeUnique<SimpleOperator<RoundOperator>>(
      ::tflite::BuiltinOperator_ROUND, OperatorType::kRound));
  ops.push_back(MakeUnique<SimpleOperator<ReluOperator>>(
      ::tflite::BuiltinOperator_RELU, OperatorType::kRelu));
  ops.push_back(MakeUnique<SimpleOperator<Relu1Operator>>(
      ::tflite::BuiltinOperator_RELU_N1_TO_1, OperatorType::kRelu1));
  ops.push_back(MakeUnique<SimpleOperator<Relu6Operator>>(
      ::tflite::BuiltinOperator_RELU6, OperatorType::kRelu6));
  ops.push_back(MakeUnique<SimpleOperator<PReluOperator>>(
      ::tflite::BuiltinOperator_PRELU, OperatorType::kPRelu));
  ops.push_back(MakeUnique<SimpleOperator<LogisticOperator>>(
      ::tflite::BuiltinOperator_LOGISTIC, OperatorType::kLogistic));
  ops.push_back(MakeUnique<SimpleOperator<TanhOperator>>(
      ::tflite::BuiltinOperator_TANH, OperatorType::kTanh));
  ops.push_back(MakeUnique<SimpleOperator<ExpOperator>>(
      ::tflite::BuiltinOperator_EXP, OperatorType::kExp));
  ops.push_back(MakeUnique<SimpleOperator<CosOperator>>(
      ::tflite::BuiltinOperator_COS, OperatorType::kCos));
  ops.push_back(MakeUnique<SimpleOperator<LogSoftmaxOperator>>(
      ::tflite::BuiltinOperator_LOG_SOFTMAX, OperatorType::kLogSoftmax));
  ops.push_back(MakeUnique<SimpleOperator<TensorFlowMaximumOperator>>(
      ::tflite::BuiltinOperator_MAXIMUM, OperatorType::kMaximum));
  ops.push_back(MakeUnique<SimpleOperator<TensorFlowMinimumOperator>>(
      ::tflite::BuiltinOperator_MINIMUM, OperatorType::kMinimum));
  ops.push_back(MakeUnique<SimpleOperator<TensorFlowGreaterOperator>>(
      ::tflite::BuiltinOperator_GREATER, OperatorType::kGreater));
  ops.push_back(MakeUnique<SimpleOperator<TensorFlowGreaterEqualOperator>>(
      ::tflite::BuiltinOperator_GREATER_EQUAL, OperatorType::kGreaterEqual));
  ops.push_back(MakeUnique<SimpleOperator<TensorFlowLessOperator>>(
      ::tflite::BuiltinOperator_LESS, OperatorType::kLess));
  ops.push_back(MakeUnique<SimpleOperator<TensorFlowLessEqualOperator>>(
      ::tflite::BuiltinOperator_LESS_EQUAL, OperatorType::kLessEqual));
  ops.push_back(MakeUnique<SimpleOperator<TensorFlowEqualOperator>>(
      ::tflite::BuiltinOperator_EQUAL, OperatorType::kEqual));
  ops.push_back(MakeUnique<SimpleOperator<TensorFlowNotEqualOperator>>(
      ::tflite::BuiltinOperator_NOT_EQUAL, OperatorType::kNotEqual));
  ops.push_back(MakeUnique<SimpleOperator<NegOperator>>(
      ::tflite::BuiltinOperator_NEG, OperatorType::kNeg));
  ops.push_back(MakeUnique<SimpleOperator<SelectOperator>>(
      ::tflite::BuiltinOperator_SELECT, OperatorType::kSelect));
  ops.push_back(MakeUnique<SimpleOperator<SliceOperator>>(
      ::tflite::BuiltinOperator_SLICE, OperatorType::kSlice));
  ops.push_back(MakeUnique<SimpleOperator<PowOperator>>(
      ::tflite::BuiltinOperator_POW, OperatorType::kPow));
  ops.push_back(MakeUnique<SimpleOperator<LogicalOrOperator>>(
      ::tflite::BuiltinOperator_LOGICAL_OR, OperatorType::kLogicalOr));
  ops.push_back(MakeUnique<SimpleOperator<LogicalAndOperator>>(
      ::tflite::BuiltinOperator_LOGICAL_AND, OperatorType::kLogicalAnd));
  ops.push_back(MakeUnique<SimpleOperator<LogicalNotOperator>>(
      ::tflite::BuiltinOperator_LOGICAL_NOT, OperatorType::kLogicalNot));
  ops.push_back(MakeUnique<SimpleOperator<FloorDivOperator>>(
      ::tflite::BuiltinOperator_FLOOR_DIV, OperatorType::kFloorDiv));
  ops.push_back(MakeUnique<SimpleOperator<FloorModOperator>>(
      ::tflite::BuiltinOperator_FLOOR_MOD, OperatorType::kFloorMod));
  ops.push_back(MakeUnique<SimpleOperator<RangeOperator>>(
      ::tflite::BuiltinOperator_RANGE, OperatorType::kRange));

  // Element-wise operator
  ops.push_back(MakeUnique<SimpleOperator<SinOperator>>(
      ::tflite::BuiltinOperator_SIN, OperatorType::kSin));
  ops.push_back(MakeUnique<SimpleOperator<LogOperator>>(
      ::tflite::BuiltinOperator_LOG, OperatorType::kLog));
  ops.push_back(MakeUnique<SimpleOperator<TensorFlowSqrtOperator>>(
      ::tflite::BuiltinOperator_SQRT, OperatorType::kSqrt));
  ops.push_back(MakeUnique<SimpleOperator<TensorFlowRsqrtOperator>>(
      ::tflite::BuiltinOperator_RSQRT, OperatorType::kRsqrt));
  ops.push_back(MakeUnique<SimpleOperator<TensorFlowSquareOperator>>(
      ::tflite::BuiltinOperator_SQUARE, OperatorType::kSquare));
  ops.push_back(MakeUnique<SimpleOperator<TensorFlowZerosLikeOperator>>(
      ::tflite::BuiltinOperator_ZEROS_LIKE, OperatorType::kZerosLike));
  ops.push_back(MakeUnique<SimpleOperator<AbsOperator>>(
      ::tflite::BuiltinOperator_ABS, OperatorType::kAbs));
  ops.push_back(MakeUnique<SimpleOperator<HardSwishOperator>>(
      ::tflite::BuiltinOperator_HARD_SWISH, OperatorType::kHardSwish));
  ops.push_back(MakeUnique<SimpleOperator<FillOperator>>(
      ::tflite::BuiltinOperator_FILL, OperatorType::kFill));
  ops.push_back(MakeUnique<SimpleOperator<ReverseV2Operator>>(
      ::tflite::BuiltinOperator_REVERSE_V2, OperatorType::kReverseV2));
  ops.push_back(MakeUnique<SimpleOperator<TensorFlowRankOperator>>(
      ::tflite::BuiltinOperator_RANK, OperatorType::kRank));
  ops.push_back(MakeUnique<SimpleOperator<SegmentSumOperator>>(
      ::tflite::BuiltinOperator_SEGMENT_SUM, OperatorType::kSegmentSum));
  ops.push_back(MakeUnique<SimpleOperator<ScatterNdOperator>>(
      ::tflite::BuiltinOperator_SCATTER_ND, OperatorType::kScatterNd));

  return ops;
}

}  // namespace
// LINT.ThenChange(//tensorflow/lite/tools/versioning/op_version.cc)
std::map<OperatorType, std::unique_ptr<BaseOperator>> BuildOperatorByTypeMap(
bool enable_select_tf_ops) {
std::map<OperatorType, std::unique_ptr<BaseOperator>> result;
std::vector<std::unique_ptr<BaseOperator>> ops =
BuildOperatorList(enable_select_tf_ops);
for (auto& op : ops) {
result[op->type()] = std::move(op);
}
return result;
}
std::map<std::string, std::unique_ptr<BaseOperator>> BuildOperatorByNameMap(
bool enable_select_tf_ops) {
std::map<std::string, std::unique_ptr<BaseOperator>> result;
std::vector<std::unique_ptr<BaseOperator>> ops =
BuildOperatorList(enable_select_tf_ops);
for (auto& op : ops) {
result[op->name()] = std::move(op);
}
return result;
}
// Decides whether the TensorFlow op named `tensorflow_op_name` should be
// exported as a Flex op. Returns true only when all of the following hold:
//   * `enable_select_tf_ops` is set,
//   * the global TensorFlow op registry can look up an OpDef for the name,
//   * the name is accepted by ::tflite::flex::IsAllowlistedFlexOp().
// A registered op that is not allowlisted triggers a warning and yields
// false, as does every other failing condition (without a warning).
bool ShouldExportAsFlexOp(bool enable_select_tf_ops,
                          const std::string& tensorflow_op_name) {
  // Flex ops are only considered when explicitly enabled.
  if (enable_select_tf_ops) {
    const tensorflow::OpDef* op_def = nullptr;
    const bool is_registered =
        tensorflow::OpRegistry::Global()
            ->LookUpOpDef(tensorflow_op_name, &op_def)
            .ok();
    if (is_registered) {
      if (::tflite::flex::IsAllowlistedFlexOp(tensorflow_op_name)) {
        return true;
      }
      // Known to TensorFlow, but not part of the flex op set.
      LOG(WARNING) << "Op " << tensorflow_op_name
                   << " is a valid TensorFlow op but has not been allowlisted for"
                      " the TensorFlow Lite flex op set.";
    }
  }
  return false;
}
} // namespace tflite
} // namespace toco