STT-tensorflow/tensorflow/lite/toco/tflite/operator.cc
A. Unique TensorFlower 16cb89bd7b Qualify uses of std::string
PiperOrigin-RevId: 317319501
Change-Id: Ib75a31ad89fa1a6bda81450f2ab5ba07d7338ada
2020-06-19 09:21:51 -07:00

2141 lines
84 KiB
C++

/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/toco/tflite/operator.h"
#include <map>
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/util/ptr_util.h"
// TODO(ycling): Consider refactoring to extract the LSTM definition out of
// graph_transformation module.
#include "tensorflow/lite/delegates/flex/whitelisted_flex_ops.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/toco/graph_transformations/lstm_utils.h"
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/tflite/builtin_operator.h"
#include "tensorflow/lite/toco/tflite/custom_operator.h"
#include "tensorflow/lite/toco/tflite/simple_operator.h"
#include "tensorflow/lite/toco/tflite/types.h"
#include "tensorflow/lite/tools/versioning/op_version.h"
namespace toco {
namespace tflite {
// LINT.IfChange
// Translates a toco ArrayDataType into the matching TFLite schema TensorType.
//
// Data types without a case below yield static_cast<::tflite::TensorType>(-1),
// the same sentinel GetVersioningOpSig uses for arrays it cannot find in the
// model.
::tflite::TensorType GetTensorType(const ArrayDataType type) {
  // A switch keeps the mapping allocation-free; the previous implementation
  // constructed and destroyed a std::map on every call.
  switch (type) {
    case ArrayDataType::kBool:
      return ::tflite::TensorType_BOOL;
    case ArrayDataType::kFloat:
      return ::tflite::TensorType_FLOAT32;
    case ArrayDataType::kInt8:
      return ::tflite::TensorType_INT8;
    case ArrayDataType::kUint8:
      return ::tflite::TensorType_UINT8;
    case ArrayDataType::kInt16:
      return ::tflite::TensorType_INT16;
    case ArrayDataType::kInt32:
      return ::tflite::TensorType_INT32;
    case ArrayDataType::kInt64:
      return ::tflite::TensorType_INT64;
    case ArrayDataType::kString:
      return ::tflite::TensorType_STRING;
    case ArrayDataType::kComplex64:
      return ::tflite::TensorType_COMPLEX64;
    case ArrayDataType::kFloat16:
      return ::tflite::TensorType_FLOAT16;
    case ArrayDataType::kFloat64:
      return ::tflite::TensorType_FLOAT64;
    default:
      // No TFLite equivalent is known for this data type.
      return static_cast<::tflite::TensorType>(-1);
  }
}
// Assembles the ::tflite::OpSignature handed to the op-versioning logic.
//
// The signature carries `op` plus one TensorType per input and per output of
// `op_signature.op`, in order. An input or output whose array is not present
// in `op_signature.model` is recorded as static_cast<TensorType>(-1).
::tflite::OpSignature GetVersioningOpSig(
    const ::tflite::BuiltinOperator op, const OperatorSignature& op_signature) {
  // Looks up the tensor type of a named array, or -1 if the model lacks it.
  const auto tensor_type_of = [&op_signature](const std::string& array_name) {
    if (!op_signature.model->HasArray(array_name)) {
      return static_cast<::tflite::TensorType>(-1);
    }
    const Array& array = op_signature.model->GetArray(array_name);
    return GetTensorType(array.data_type);
  };

  std::vector<::tflite::TensorType> input_types;
  for (const auto& name : op_signature.op->inputs) {
    input_types.push_back(tensor_type_of(name));
  }

  std::vector<::tflite::TensorType> output_types;
  for (const auto& name : op_signature.op->outputs) {
    output_types.push_back(tensor_type_of(name));
  }

  return ::tflite::OpSignature{op, input_types, output_types};
}
// Converts between toco's AveragePoolOperator and the TFLite Pool2DOptions
// table.
class AveragePool
    : public BuiltinOperator<AveragePoolOperator, ::tflite::Pool2DOptions,
                             ::tflite::BuiltinOptions_Pool2DOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Serializes the padding, strides, filter size and fused activation of `op`
  // into a Pool2DOptions table built with `builder`.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto tfl_padding = Padding::Serialize(op.padding.type);
    const auto tfl_activation =
        ActivationFunction::Serialize(op.fused_activation_function);
    const auto pool_options = ::tflite::CreatePool2DOptions(
        *builder, tfl_padding, op.stride_width, op.stride_height, op.kwidth,
        op.kheight, tfl_activation);
    return pool_options;
  }

  // Populates `op` from a Pool2DOptions table.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    op->padding.type = Padding::Deserialize(options.padding());
    // Filter extent.
    op->kwidth = options.filter_width();
    op->kheight = options.filter_height();
    // Strides.
    op->stride_width = options.stride_w();
    op->stride_height = options.stride_h();
    op->fused_activation_function =
        ActivationFunction::Deserialize(options.fused_activation_function());
  }
};
// Converts between toco's ConvOperator and the TFLite Conv2DOptions table.
class Convolution
    : public BuiltinOperator<ConvOperator, ::tflite::Conv2DOptions,
                             ::tflite::BuiltinOptions_Conv2DOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Emits a Conv2DOptions table carrying the padding, strides, fused
  // activation and dilation factors of `op`.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto tfl_padding = Padding::Serialize(op.padding.type);
    const auto tfl_activation =
        ActivationFunction::Serialize(op.fused_activation_function);
    const auto conv_options = ::tflite::CreateConv2DOptions(
        *builder, tfl_padding, op.stride_width, op.stride_height,
        tfl_activation, op.dilation_width_factor, op.dilation_height_factor);
    return conv_options;
  }

  // Populates `op` from a Conv2DOptions table.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    op->padding.type = Padding::Deserialize(options.padding());
    // Strides.
    op->stride_height = options.stride_h();
    op->stride_width = options.stride_w();
    // Dilation factors.
    op->dilation_height_factor = options.dilation_h_factor();
    op->dilation_width_factor = options.dilation_w_factor();
    op->fused_activation_function =
        ActivationFunction::Deserialize(options.fused_activation_function());
  }
};
// Converts between toco's DepthwiseConvOperator and the TFLite
// DepthwiseConv2DOptions table, and derives the builtin op version.
class DepthwiseConvolution
    : public BuiltinOperator<DepthwiseConvOperator,
                             ::tflite::DepthwiseConv2DOptions,
                             ::tflite::BuiltinOptions_DepthwiseConv2DOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Emits a DepthwiseConv2DOptions table carrying the padding, strides, depth
  // multiplier, fused activation and dilation factors of `op`.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto tfl_padding = Padding::Serialize(op.padding.type);
    const auto tfl_activation =
        ActivationFunction::Serialize(op.fused_activation_function);
    const auto dw_options = ::tflite::CreateDepthwiseConv2DOptions(
        *builder, tfl_padding, op.stride_width, op.stride_height,
        op.depth_multiplier, tfl_activation, op.dilation_width_factor,
        op.dilation_height_factor);
    return dw_options;
  }

  // Populates `op` from a DepthwiseConv2DOptions table.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    op->padding.type = Padding::Deserialize(options.padding());
    op->stride_height = options.stride_h();
    op->stride_width = options.stride_w();
    op->depth_multiplier = options.depth_multiplier();
    op->fused_activation_function =
        ActivationFunction::Deserialize(options.fused_activation_function());
    op->dilation_height_factor = options.dilation_h_factor();
    op->dilation_width_factor = options.dilation_w_factor();
  }

  // Returns the builtin operator version; the dilation factors of the op are
  // forwarded to the versioning logic through the op signature.
  int GetVersion(const OperatorSignature& op_signature) const override {
    const auto& depthwise_op =
        static_cast<const DepthwiseConvOperator&>(*op_signature.op);
    ::tflite::OpSignature versioning_sig =
        GetVersioningOpSig(builtin_op(), op_signature);
    auto& dw_sig_options = versioning_sig.options.depthwise_conv_2d;
    dw_sig_options.dilation_w_factor = depthwise_op.dilation_width_factor;
    dw_sig_options.dilation_h_factor = depthwise_op.dilation_height_factor;
    return ::tflite::GetBuiltinOperatorVersion(versioning_sig);
  }
};
// Converts between toco's AddOperator and the TFLite AddOptions table.
class Add : public BuiltinOperator<AddOperator, ::tflite::AddOptions,
                                   ::tflite::BuiltinOptions_AddOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Emits an AddOptions table holding the fused activation of `op`.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto tfl_activation =
        ActivationFunction::Serialize(op.fused_activation_function);
    const auto add_options =
        ::tflite::CreateAddOptions(*builder, tfl_activation);
    return add_options;
  }

  // Restores the fused activation of `op` from an AddOptions table.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    const auto tfl_activation = options.fused_activation_function();
    op->fused_activation_function =
        ActivationFunction::Deserialize(tfl_activation);
  }
};
// Converts between toco's AddNOperator and the TFLite AddNOptions table.
class AddN : public BuiltinOperator<AddNOperator, ::tflite::AddNOptions,
                                    ::tflite::BuiltinOptions_AddNOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Emits an AddNOptions table; no field of `op` is written into it.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto addn_options = ::tflite::CreateAddNOptions(*builder);
    return addn_options;
  }

  // Nothing is read back from the options table.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    // Intentionally empty.
  }
};
// Converts between toco's SpaceToBatchNDOperator and the TFLite
// SpaceToBatchNDOptions table, and derives the builtin op version.
class SpaceToBatchND
    : public BuiltinOperator<SpaceToBatchNDOperator,
                             ::tflite::SpaceToBatchNDOptions,
                             ::tflite::BuiltinOptions_SpaceToBatchNDOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Emits a SpaceToBatchNDOptions table; no field of `op` is written into it.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    return ::tflite::CreateSpaceToBatchNDOptions(*builder);
  }

  // Nothing is read back from the options table.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {}

  // Returns the builtin operator version. The rank of the first input is
  // forwarded to the versioning logic when that input has a shape.
  int GetVersion(const OperatorSignature& op_signature) const override {
    const std::string& input_name = op_signature.op->inputs[0];
    const Array& input_array = op_signature.model->GetArray(input_name);
    ::tflite::OpSignature op_sig =
        GetVersioningOpSig(builtin_op(), op_signature);
    // Only query the shape when one is present, mirroring the has_shape()
    // guard used by Sub and Div in this file. Without a shape, num_dims keeps
    // whatever value GetVersioningOpSig left in the signature.
    if (input_array.has_shape()) {
      op_sig.options.single_input_op.num_dims =
          input_array.shape().dimensions_count();
    }
    return ::tflite::GetBuiltinOperatorVersion(op_sig);
  }
};
// Converts between toco's SubOperator and the TFLite SubOptions table, and
// derives the builtin op version.
class Sub : public BuiltinOperator<SubOperator, ::tflite::SubOptions,
                                   ::tflite::BuiltinOptions_SubOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Emits a SubOptions table holding the fused activation of `op`.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto tfl_activation =
        ActivationFunction::Serialize(op.fused_activation_function);
    const auto sub_options =
        ::tflite::CreateSubOptions(*builder, tfl_activation);
    return sub_options;
  }

  // Restores the fused activation of `op` from a SubOptions table.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    const auto tfl_activation = options.fused_activation_function();
    op->fused_activation_function =
        ActivationFunction::Deserialize(tfl_activation);
  }

  // Returns the builtin operator version. When both inputs have a shape, the
  // larger rank and whether the two shapes differ are forwarded to the
  // versioning logic.
  int GetVersion(const OperatorSignature& op_signature) const override {
    const auto& operand_names = op_signature.op->inputs;
    const Array& lhs = op_signature.model->GetArray(operand_names[0]);
    const Array& rhs = op_signature.model->GetArray(operand_names[1]);
    ::tflite::OpSignature versioning_sig =
        GetVersioningOpSig(builtin_op(), op_signature);
    // Broadcast details can only be derived when both shapes are present.
    if (!lhs.has_shape() || !rhs.has_shape()) {
      return ::tflite::GetBuiltinOperatorVersion(versioning_sig);
    }
    auto& broadcast = versioning_sig.options.broadcast;
    broadcast.num_dims = std::max(lhs.shape().dimensions_count(),
                                  rhs.shape().dimensions_count());
    broadcast.need_broadcast = (lhs.shape() != rhs.shape());
    return ::tflite::GetBuiltinOperatorVersion(versioning_sig);
  }
};
// Converts between toco's DivOperator and the TFLite DivOptions table, and
// derives the builtin op version.
class Div : public BuiltinOperator<DivOperator, ::tflite::DivOptions,
                                   ::tflite::BuiltinOptions_DivOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Emits a DivOptions table holding the fused activation of `op`.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto tfl_activation =
        ActivationFunction::Serialize(op.fused_activation_function);
    const auto div_options =
        ::tflite::CreateDivOptions(*builder, tfl_activation);
    return div_options;
  }

  // Restores the fused activation of `op` from a DivOptions table.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    const auto tfl_activation = options.fused_activation_function();
    op->fused_activation_function =
        ActivationFunction::Deserialize(tfl_activation);
  }

  // Returns the builtin operator version. When both inputs have a shape, the
  // larger rank and whether the two shapes differ are forwarded to the
  // versioning logic.
  int GetVersion(const OperatorSignature& op_signature) const override {
    const auto& operand_names = op_signature.op->inputs;
    const Array& numerator = op_signature.model->GetArray(operand_names[0]);
    const Array& denominator = op_signature.model->GetArray(operand_names[1]);
    ::tflite::OpSignature versioning_sig =
        GetVersioningOpSig(builtin_op(), op_signature);
    const bool shapes_known =
        numerator.has_shape() && denominator.has_shape();
    if (shapes_known) {
      const auto numerator_rank = numerator.shape().dimensions_count();
      const auto denominator_rank = denominator.shape().dimensions_count();
      versioning_sig.options.broadcast.num_dims =
          std::max(numerator_rank, denominator_rank);
      versioning_sig.options.broadcast.need_broadcast =
          (numerator.shape() != denominator.shape());
    }
    return ::tflite::GetBuiltinOperatorVersion(versioning_sig);
  }
};
// Converts between toco's BatchToSpaceNDOperator and the TFLite
// BatchToSpaceNDOptions table, and derives the builtin op version.
class BatchToSpaceND
    : public BuiltinOperator<BatchToSpaceNDOperator,
                             ::tflite::BatchToSpaceNDOptions,
                             ::tflite::BuiltinOptions_BatchToSpaceNDOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Emits a BatchToSpaceNDOptions table; no field of `op` is written into it.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    return ::tflite::CreateBatchToSpaceNDOptions(*builder);
  }

  // Nothing is read back from the options table.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {}

  // Returns the builtin operator version. The rank of the first input is
  // forwarded to the versioning logic when that input has a shape.
  int GetVersion(const OperatorSignature& op_signature) const override {
    const std::string& input_name = op_signature.op->inputs[0];
    const Array& input_array = op_signature.model->GetArray(input_name);
    ::tflite::OpSignature op_sig =
        GetVersioningOpSig(builtin_op(), op_signature);
    // Only query the shape when one is present, mirroring the has_shape()
    // guard used by Sub and Div in this file. Without a shape, num_dims keeps
    // whatever value GetVersioningOpSig left in the signature.
    if (input_array.has_shape()) {
      op_sig.options.single_input_op.num_dims =
          input_array.shape().dimensions_count();
    }
    return ::tflite::GetBuiltinOperatorVersion(op_sig);
  }
};
// Converts between toco's CastOperator and the TFLite CastOptions table.
class Cast : public BuiltinOperator<CastOperator, ::tflite::CastOptions,
                                    ::tflite::BuiltinOptions_CastOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Emits a CastOptions table with the source and destination data types of
  // `op`.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto cast_options = ::tflite::CreateCastOptions(
        *builder, DataType::Serialize(op.src_data_type),
        DataType::Serialize(op.dst_data_type));
    return cast_options;
  }

  // Restores the source and destination data types of `op`.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    const auto tfl_source_type = options.in_data_type();
    op->src_data_type = DataType::Deserialize(tfl_source_type);
    const auto tfl_target_type = options.out_data_type();
    op->dst_data_type = DataType::Deserialize(tfl_target_type);
  }
};
// Converts between toco's ConcatenationOperator and the TFLite
// ConcatenationOptions table.
class Concatenation
    : public BuiltinOperator<ConcatenationOperator,
                             ::tflite::ConcatenationOptions,
                             ::tflite::BuiltinOptions_ConcatenationOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Emits a ConcatenationOptions table holding the axis of `op`.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto concat_options =
        ::tflite::CreateConcatenationOptions(*builder, op.axis);
    return concat_options;
  }

  // Restores the axis of `op` from a ConcatenationOptions table.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    const auto stored_axis = options.axis();
    op->axis = stored_axis;
  }
};
// Converts between toco's DepthToSpaceOperator and the TFLite
// DepthToSpaceOptions table.
class DepthToSpace
    : public BuiltinOperator<DepthToSpaceOperator,
                             ::tflite::DepthToSpaceOptions,
                             ::tflite::BuiltinOptions_DepthToSpaceOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Emits a DepthToSpaceOptions table holding the block size of `op`.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto d2s_options =
        ::tflite::CreateDepthToSpaceOptions(*builder, op.block_size);
    return d2s_options;
  }

  // Restores the block size of `op` from a DepthToSpaceOptions table.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    const auto stored_block_size = options.block_size();
    op->block_size = stored_block_size;
  }
};
// Converts between toco's FakeQuantOperator and the TFLite FakeQuantOptions
// table, and derives the builtin op version.
class FakeQuant
    : public BuiltinOperator<FakeQuantOperator, ::tflite::FakeQuantOptions,
                             ::tflite::BuiltinOptions_FakeQuantOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Emits a FakeQuantOptions table with the min/max range, bit count and
  // narrow-range flag of `op`.
  // NOTE(review): `op.minmax` is dereferenced unconditionally — callers
  // presumably guarantee it is set; confirm.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto& range = *op.minmax;
    const auto fq_options = ::tflite::CreateFakeQuantOptions(
        *builder, range.min, range.max, op.num_bits, op.narrow_range);
    return fq_options;
  }

  // Populates `op` from a FakeQuantOptions table; any previous `minmax` held
  // by `op` is replaced with a freshly allocated one.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    auto* range = new MinMax;
    range->min = options.min();
    range->max = options.max();
    // Ownership moves to the operator only once the range is fully filled in.
    op->minmax.reset(range);
    op->num_bits = options.num_bits();
    op->narrow_range = options.narrow_range();
  }

  // Returns the builtin operator version; the narrow-range flag is forwarded
  // to the versioning logic through the op signature.
  int GetVersion(const OperatorSignature& op_signature) const override {
    const auto& fake_quant_op =
        static_cast<const FakeQuantOperator&>(*op_signature.op);
    ::tflite::OpSignature versioning_sig =
        GetVersioningOpSig(builtin_op(), op_signature);
    versioning_sig.options.fakequant.narrow_range = fake_quant_op.narrow_range;
    return ::tflite::GetBuiltinOperatorVersion(versioning_sig);
  }
};
// Converts between toco's FullyConnectedOperator and the TFLite
// FullyConnectedOptions table, and derives the builtin op version.
class FullyConnected
    : public BuiltinOperator<FullyConnectedOperator,
                             ::tflite::FullyConnectedOptions,
                             ::tflite::BuiltinOptions_FullyConnectedOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Maps a toco weights format onto the TFLite schema enum. Formats that are
  // not recognized are logged as an error and mapped to DEFAULT.
  ::tflite::FullyConnectedOptionsWeightsFormat GetWeightFormat(
      FullyConnectedWeightsFormat fmt) const {
    if (fmt == FullyConnectedWeightsFormat::kDefault) {
      return ::tflite::FullyConnectedOptionsWeightsFormat_DEFAULT;
    }
    if (fmt == FullyConnectedWeightsFormat::kShuffled4x16Int8) {
      return ::tflite::FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8;
    }
    LOG(ERROR) << "Unhandled FC weights format";
    return ::tflite::FullyConnectedOptionsWeightsFormat_DEFAULT;
  }

  // Emits a FullyConnectedOptions table with the fused activation and weights
  // format of `op`.
  // NOTE(review): GetVersion() consults `keep_num_dims`, but that field is
  // not serialized here nor restored in ReadOptions() — confirm this is
  // intended.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto tfl_activation =
        ActivationFunction::Serialize(op.fused_activation_function);
    const auto tfl_weights_format = GetWeightFormat(op.weights_format);
    return ::tflite::CreateFullyConnectedOptions(*builder, tfl_activation,
                                                 tfl_weights_format);
  }

  // Populates `op` from a FullyConnectedOptions table. A weights format that
  // is not recognized is logged as an error and read as kDefault.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    op->fused_activation_function =
        ActivationFunction::Deserialize(options.fused_activation_function());
    const auto tfl_weights_format = options.weights_format();
    if (tfl_weights_format ==
        ::tflite::FullyConnectedOptionsWeightsFormat_DEFAULT) {
      op->weights_format = FullyConnectedWeightsFormat::kDefault;
    } else if (tfl_weights_format ==
               ::tflite::FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8) {
      op->weights_format = FullyConnectedWeightsFormat::kShuffled4x16Int8;
    } else {
      LOG(ERROR) << "Unhandled FC weights format";
      op->weights_format = FullyConnectedWeightsFormat::kDefault;
    }
  }

  // Returns the builtin operator version. keep_num_dims and the weights
  // format come from the operator; sparse_weight is always reported as false.
  int GetVersion(const OperatorSignature& op_signature) const override {
    const auto& fully_connected_op =
        static_cast<const FullyConnectedOperator&>(*op_signature.op);
    ::tflite::OpSignature versioning_sig =
        GetVersioningOpSig(builtin_op(), op_signature);
    auto& fc_sig_options = versioning_sig.options.fully_connected;
    fc_sig_options.keep_num_dims = fully_connected_op.keep_num_dims;
    fc_sig_options.weights_format =
        GetWeightFormat(fully_connected_op.weights_format);
    fc_sig_options.sparse_weight = false;
    return ::tflite::GetBuiltinOperatorVersion(versioning_sig);
  }
};
// Converts between toco's GatherOperator and the TFLite GatherOptions table.
class Gather : public BuiltinOperator<GatherOperator, ::tflite::GatherOptions,
                                      ::tflite::BuiltinOptions_GatherOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Emits a GatherOptions table with the axis of `op`, or 0 when the operator
  // has no axis set.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    int gather_axis = 0;
    if (op.axis) {
      gather_axis = op.axis.value();
    }
    return ::tflite::CreateGatherOptions(*builder, gather_axis);
  }

  // Restores the axis of `op` from a GatherOptions table.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    op->axis = {options.axis()};
  }
};
// Converts between toco's GatherNdOperator and the TFLite GatherNdOptions
// table.
class GatherNd
    : public BuiltinOperator<GatherNdOperator, ::tflite::GatherNdOptions,
                             ::tflite::BuiltinOptions_GatherNdOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Emits a GatherNdOptions table; no field of `op` is written into it.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto gather_nd_options = ::tflite::CreateGatherNdOptions(*builder);
    return gather_nd_options;
  }

  // Nothing is read back from the options table.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    // Intentionally empty.
  }
};
// Converts between toco's SvdfOperator and the TFLite SVDFOptions table.
class Svdf : public BuiltinOperator<SvdfOperator, ::tflite::SVDFOptions,
                                    ::tflite::BuiltinOptions_SVDFOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Emits an SVDFOptions table with the rank and fused activation of `op`.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto tfl_activation =
        ActivationFunction::Serialize(op.fused_activation_function);
    const auto svdf_options =
        ::tflite::CreateSVDFOptions(*builder, op.rank, tfl_activation);
    return svdf_options;
  }

  // Restores the fused activation and rank of `op` from an SVDFOptions table.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    const auto tfl_activation = options.fused_activation_function();
    op->fused_activation_function =
        ActivationFunction::Deserialize(tfl_activation);
    op->rank = options.rank();
  }
};
// Converts between toco's L2NormalizationOperator and the TFLite
// L2NormOptions table.
class L2Normalization
    : public BuiltinOperator<L2NormalizationOperator, ::tflite::L2NormOptions,
                             ::tflite::BuiltinOptions_L2NormOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Emits an L2NormOptions table holding the fused activation of `op`.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto tfl_activation =
        ActivationFunction::Serialize(op.fused_activation_function);
    const auto l2norm_options =
        ::tflite::CreateL2NormOptions(*builder, tfl_activation);
    return l2norm_options;
  }

  // Restores the fused activation of `op` from an L2NormOptions table.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    const auto tfl_activation = options.fused_activation_function();
    op->fused_activation_function =
        ActivationFunction::Deserialize(tfl_activation);
  }
};
// Converts between toco's L2PoolOperator and the TFLite Pool2DOptions table.
class L2Pool : public BuiltinOperator<L2PoolOperator, ::tflite::Pool2DOptions,
                                      ::tflite::BuiltinOptions_Pool2DOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Serializes the padding, strides, filter size and fused activation of `op`
  // into a Pool2DOptions table built with `builder`.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto tfl_padding = Padding::Serialize(op.padding.type);
    const auto tfl_activation =
        ActivationFunction::Serialize(op.fused_activation_function);
    const auto pool_options = ::tflite::CreatePool2DOptions(
        *builder, tfl_padding, op.stride_width, op.stride_height, op.kwidth,
        op.kheight, tfl_activation);
    return pool_options;
  }

  // Populates `op` from a Pool2DOptions table.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    op->padding.type = Padding::Deserialize(options.padding());
    // Width-related fields.
    op->stride_width = options.stride_w();
    op->kwidth = options.filter_width();
    // Height-related fields.
    op->stride_height = options.stride_h();
    op->kheight = options.filter_height();
    op->fused_activation_function =
        ActivationFunction::Deserialize(options.fused_activation_function());
  }
};
// Converts between toco's LocalResponseNormalizationOperator and the TFLite
// LocalResponseNormalizationOptions table.
class LocalResponseNormalization
    : public BuiltinOperator<
          LocalResponseNormalizationOperator,
          ::tflite::LocalResponseNormalizationOptions,
          ::tflite::BuiltinOptions_LocalResponseNormalizationOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Emits a LocalResponseNormalizationOptions table from the range, bias,
  // alpha and beta of `op`.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto lrn_options = ::tflite::CreateLocalResponseNormalizationOptions(
        *builder, op.range, op.bias, op.alpha, op.beta);
    return lrn_options;
  }

  // Populates `op` from the table. toco's `range` field is stored under the
  // name `radius` in the TFLite options.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    op->beta = options.beta();
    op->alpha = options.alpha();
    op->bias = options.bias();
    op->range = options.radius();
  }
};
// Converts between toco's MaxPoolOperator and the TFLite Pool2DOptions table.
class MaxPool : public BuiltinOperator<MaxPoolOperator, ::tflite::Pool2DOptions,
                                       ::tflite::BuiltinOptions_Pool2DOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Serializes the padding, strides, filter size and fused activation of `op`
  // into a Pool2DOptions table built with `builder`.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto tfl_padding = Padding::Serialize(op.padding.type);
    const auto tfl_activation =
        ActivationFunction::Serialize(op.fused_activation_function);
    const auto pool_options = ::tflite::CreatePool2DOptions(
        *builder, tfl_padding, op.stride_width, op.stride_height, op.kwidth,
        op.kheight, tfl_activation);
    return pool_options;
  }

  // Populates `op` from a Pool2DOptions table.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    op->padding.type = Padding::Deserialize(options.padding());
    // Filter extent first, then strides.
    op->kheight = options.filter_height();
    op->kwidth = options.filter_width();
    op->stride_height = options.stride_h();
    op->stride_width = options.stride_w();
    op->fused_activation_function =
        ActivationFunction::Deserialize(options.fused_activation_function());
  }
};
// Converts between toco's MulOperator and the TFLite MulOptions table, and
// derives the builtin op version.
class Mul : public BuiltinOperator<MulOperator, ::tflite::MulOptions,
                                   ::tflite::BuiltinOptions_MulOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Emits a MulOptions table holding the fused activation of `op`.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto tfl_activation =
        ActivationFunction::Serialize(op.fused_activation_function);
    const auto mul_options =
        ::tflite::CreateMulOptions(*builder, tfl_activation);
    return mul_options;
  }

  // Restores the fused activation of `op` from a MulOptions table.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    const auto tfl_activation = options.fused_activation_function();
    op->fused_activation_function =
        ActivationFunction::Deserialize(tfl_activation);
  }

  // Returns the builtin operator version. The quantization scales of both
  // inputs and of the output are forwarded to the versioning logic; an array
  // without quantization parameters contributes a scale of 0.
  int GetVersion(const OperatorSignature& op_signature) const override {
    const auto& toco_op = *op_signature.op;
    const Array& lhs = op_signature.model->GetArray(toco_op.inputs[0]);
    const Array& rhs = op_signature.model->GetArray(toco_op.inputs[1]);
    const Array& result = op_signature.model->GetArray(toco_op.outputs[0]);

    float lhs_scale = 0.0f;
    if (lhs.quantization_params) {
      lhs_scale = lhs.quantization_params->scale;
    }
    float rhs_scale = 0.0f;
    if (rhs.quantization_params) {
      rhs_scale = rhs.quantization_params->scale;
    }
    float result_scale = 0.0f;
    if (result.quantization_params) {
      result_scale = result.quantization_params->scale;
    }

    ::tflite::OpSignature versioning_sig =
        GetVersioningOpSig(builtin_op(), op_signature);
    auto& mul_sig_options = versioning_sig.options.mul;
    mul_sig_options.input1_scale = lhs_scale;
    mul_sig_options.input2_scale = rhs_scale;
    mul_sig_options.output_scale = result_scale;
    return ::tflite::GetBuiltinOperatorVersion(versioning_sig);
  }
};
// Converts between toco's PadOperator and the TFLite PadOptions table.
class Pad : public BuiltinOperator<PadOperator, ::tflite::PadOptions,
                                   ::tflite::BuiltinOptions_PadOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Emits a PadOptions table; no field of `op` is written into it.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto pad_options = ::tflite::CreatePadOptions(*builder);
    return pad_options;
  }

  // Nothing is read back from the options table.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    // Intentionally empty.
  }
};
// Converts between toco's TensorFlowTileOperator and the TFLite TileOptions
// table.
class Tile
    : public BuiltinOperator<TensorFlowTileOperator, ::tflite::TileOptions,
                             ::tflite::BuiltinOptions_TileOptions> {
  // The access specifier was missing, which left the overrides below private
  // by default; every sibling operator class in this file declares them
  // public.
 public:
  using BuiltinOperator::BuiltinOperator;

  // Emits a TileOptions table; no field of `op` is written into it.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    return ::tflite::CreateTileOptions(*builder);
  }

  // Nothing is read back from the options table.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {}
};
// Converts between toco's PadV2Operator and the TFLite PadV2Options table.
class PadV2 : public BuiltinOperator<PadV2Operator, ::tflite::PadV2Options,
                                     ::tflite::BuiltinOptions_PadV2Options> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Emits a PadV2Options table; no field of `op` is written into it.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto padv2_options = ::tflite::CreatePadV2Options(*builder);
    return padv2_options;
  }

  // Nothing is read back from the options table.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    // Intentionally empty.
  }
};
// Converts between toco's TensorFlowReshapeOperator and the TFLite
// ReshapeOptions table.
class Reshape
    : public BuiltinOperator<TensorFlowReshapeOperator,
                             ::tflite::ReshapeOptions,
                             ::tflite::BuiltinOptions_ReshapeOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Emits a ReshapeOptions table whose new_shape vector is a copy of
  // `op.shape`.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    return ::tflite::CreateReshapeOptions(*builder,
                                          builder->CreateVector(op.shape));
  }

  // Appends the stored new_shape entries to `op->shape`. If the table carries
  // no new_shape vector, `op->shape` is left untouched.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    // FlatBuffers accessors return nullptr for a vector field that was never
    // written, so it must be checked before it is dereferenced.
    const auto* new_shape = options.new_shape();
    if (new_shape == nullptr) {
      return;
    }
    op->shape.insert(op->shape.end(), new_shape->begin(), new_shape->end());
  }
};
// Converts between toco's SoftmaxOperator and the TFLite SoftmaxOptions
// table.
class Softmax
    : public BuiltinOperator<SoftmaxOperator, ::tflite::SoftmaxOptions,
                             ::tflite::BuiltinOptions_SoftmaxOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Emits a SoftmaxOptions table holding the beta of `op`.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto softmax_options =
        ::tflite::CreateSoftmaxOptions(*builder, op.beta);
    return softmax_options;
  }

  // Restores the beta of `op` from a SoftmaxOptions table.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    const auto stored_beta = options.beta();
    op->beta = stored_beta;
  }
};
// Converts between toco's SpaceToDepthOperator and the TFLite
// SpaceToDepthOptions table.
class SpaceToDepth
    : public BuiltinOperator<SpaceToDepthOperator,
                             ::tflite::SpaceToDepthOptions,
                             ::tflite::BuiltinOptions_SpaceToDepthOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Emits a SpaceToDepthOptions table holding the block size of `op`.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto s2d_options =
        ::tflite::CreateSpaceToDepthOptions(*builder, op.block_size);
    return s2d_options;
  }

  // Restores the block size of `op` from a SpaceToDepthOptions table.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    const auto stored_block_size = options.block_size();
    op->block_size = stored_block_size;
  }
};
// Converts between toco's TransposeOperator and the TFLite TransposeOptions
// table.
class Transpose
    : public BuiltinOperator<TransposeOperator, ::tflite::TransposeOptions,
                             ::tflite::BuiltinOptions_TransposeOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Emits a TransposeOptions table; no field of `op` is written into it.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto transpose_options = ::tflite::CreateTransposeOptions(*builder);
    return transpose_options;
  }

  // Nothing is read back from the options table.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    // Intentionally empty.
  }
};
// Converts between toco's LstmCellOperator and the TFLite LSTMOptions table,
// derives the builtin op version, and reports which inputs are mutated.
class Lstm : public BuiltinOperator<LstmCellOperator, ::tflite::LSTMOptions,
                                    ::tflite::BuiltinOptions_LSTMOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Maps a toco LSTM kernel type onto the TFLite schema enum. A kernel type
  // that is not recognized is logged as an error and mapped to
  // static_cast<::tflite::LSTMKernelType>(-1).
  ::tflite::LSTMKernelType GetKernelType(
      LstmCellOperator::KernelType type) const {
    switch (type) {
      case LstmCellOperator::KERNEL_BASIC:
        return ::tflite::LSTMKernelType_BASIC;
      case LstmCellOperator::KERNEL_FULL:
        return ::tflite::LSTMKernelType_FULL;
      default:
        LOG(ERROR) << "Unhandled Kernel Type";
        return static_cast<::tflite::LSTMKernelType>(-1);
    }
  }

  // Emits an LSTMOptions table. Only the kernel type comes from `op`; the
  // activation is always TANH and both clip values are always 0.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const ::tflite::LSTMKernelType tfl_kernel_type =
        GetKernelType(op.kernel_type);
    // Current toco converter only supports tanh, no clip.
    return ::tflite::CreateLSTMOptions(*builder, /*fused_activation_function=*/
                                       ::tflite::ActivationFunctionType_TANH,
                                       /*cell_clip=*/0.0,
                                       /*proj_clip=*/0.0, tfl_kernel_type);
  }

  // Restores the kernel type of `op`. A stored kernel type other than BASIC
  // or FULL leaves `op->kernel_type` unchanged.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    // Only support tanh activation, so check that tflite type is tanh.
    CHECK(options.fused_activation_function() ==
          ::tflite::ActivationFunctionType_TANH);
    const auto tfl_kernel_type = options.kernel_type();
    if (tfl_kernel_type == ::tflite::LSTMKernelType_BASIC) {
      op->kernel_type = LstmCellOperator::KERNEL_BASIC;
    } else if (tfl_kernel_type == ::tflite::LSTMKernelType_FULL) {
      op->kernel_type = LstmCellOperator::KERNEL_FULL;
    }
  }

  // Returns the builtin operator version; the kernel type is forwarded to
  // the versioning logic through the op signature.
  int GetVersion(const OperatorSignature& op_signature) const override {
    const auto& lstm_cell =
        static_cast<const LstmCellOperator&>(*op_signature.op);
    ::tflite::OpSignature versioning_sig =
        GetVersioningOpSig(builtin_op(), op_signature);
    versioning_sig.options.lstm.kernel_type =
        GetKernelType(lstm_cell.kernel_type);
    return ::tflite::GetBuiltinOperatorVersion(versioning_sig);
  }

  // Returns one flag per input of `op`, true for the inputs that are marked
  // as mutating variables. Which two inputs are flagged depends on the
  // kernel type; any other kernel type flags nothing.
  std::vector<bool> GetMutatingInputVariables(
      const Operator& op) const override {
    const auto& lstm_cell = static_cast<const LstmCellOperator&>(op);
    std::vector<bool> is_mutating(op.inputs.size(), false);
    if (lstm_cell.kernel_type == LstmCellOperator::KERNEL_FULL) {
      is_mutating[kInputActivationStateTensor] = true;
      is_mutating[kInputCellStateTensor] = true;
    } else if (lstm_cell.kernel_type == LstmCellOperator::KERNEL_BASIC) {
      is_mutating[LstmCellOperator::PREV_ACTIV_INPUT] = true;
      is_mutating[LstmCellOperator::PREV_STATE_INPUT] = true;
    }
    return is_mutating;
  }
};
// Converts between toco's UnidirectionalSequenceLstmOperator and the TFLite
// UnidirectionalSequenceLSTMOptions table, and reports which inputs are
// mutated.
class UnidirectionalSequenceLstm
    : public BuiltinOperator<
          UnidirectionalSequenceLstmOperator,
          ::tflite::UnidirectionalSequenceLSTMOptions,
          ::tflite::BuiltinOptions_UnidirectionalSequenceLSTMOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Emits an options table with fixed values: TANH activation, both clip
  // values 0 and time_major true. No field of `op` is written into it.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    // Current toco converter only supports tanh, no clip.
    const auto lstm_options =
        ::tflite::CreateUnidirectionalSequenceLSTMOptions(
            *builder, /*fused_activation_function=*/
            ::tflite::ActivationFunctionType_TANH,
            /*cell_clip=*/0.0,
            /*proj_clip=*/0.0,
            /*time_major=*/true);
    return lstm_options;
  }

  // Nothing is copied into `op`; the stored activation is only validated via
  // DCHECK.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    // Only support tanh activation, so check that tflite type is tanh.
    DCHECK(options.fused_activation_function() ==
           ::tflite::ActivationFunctionType_TANH);
  }

  // Returns one flag per input of `op`; the activation-state and cell-state
  // inputs are flagged as mutating variables.
  std::vector<bool> GetMutatingInputVariables(
      const Operator& op) const override {
    std::vector<bool> is_mutating(op.inputs.size(), false);
    is_mutating[kInputActivationStateTensor] = true;
    is_mutating[kInputCellStateTensor] = true;
    return is_mutating;
  }
};
// Converts between toco's BidirectionalSequenceLstmOperator and the TFLite
// BidirectionalSequenceLSTMOptions table, and reports which inputs are
// mutated.
class BidirectionalSequenceLstm
    : public BuiltinOperator<
          BidirectionalSequenceLstmOperator,
          ::tflite::BidirectionalSequenceLSTMOptions,
          ::tflite::BuiltinOptions_BidirectionalSequenceLSTMOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Emits an options table. Only merge_outputs comes from `op`; the
  // activation is always TANH, both clip values are 0 and time_major is true.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    // Current toco converter only supports tanh, no clip.
    const auto lstm_options =
        ::tflite::CreateBidirectionalSequenceLSTMOptions(
            *builder, /*fused_activation_function=*/
            ::tflite::ActivationFunctionType_TANH,
            /*cell_clip=*/0.0,
            /*proj_clip=*/0.0,
            /*merge_outputs=*/op.merge_outputs,
            /*time_major=*/true);
    return lstm_options;
  }

  // Restores merge_outputs of `op`; the stored activation is validated via
  // DCHECK.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    // Only support tanh activation, so check that tflite type is tanh.
    DCHECK(options.fused_activation_function() ==
           ::tflite::ActivationFunctionType_TANH);
    op->merge_outputs = options.merge_outputs();
  }

  // Returns one flag per input of `op`; the four state inputs listed below
  // are flagged as mutating variables.
  std::vector<bool> GetMutatingInputVariables(
      const Operator& op) const override {
    constexpr int kStateInputIndices[] = {
        35,  // Forward input activation state.
        36,  // Forward input cell state.
        37,  // Backward input activation state.
        38,  // Backward input cell state.
    };
    std::vector<bool> is_mutating(op.inputs.size(), false);
    for (const int state_index : kStateInputIndices) {
      is_mutating[state_index] = true;
    }
    return is_mutating;
  }
};
// Serializer for the bidirectional sequence RNN builtin. Only
// `merge_outputs` round-trips; time-major layout and tanh activation are
// written as fixed values.
class BidirectionalSequenceRnn
    : public BuiltinOperator<
          BidirectionalSequenceRnnOperator,
          ::tflite::BidirectionalSequenceRNNOptions,
          ::tflite::BuiltinOptions_BidirectionalSequenceRNNOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Writes the option table: time-major, tanh, and the merge_outputs flag.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    // The toco converter currently emits tanh only.
    const auto activation = ::tflite::ActivationFunctionType_TANH;
    return ::tflite::CreateBidirectionalSequenceRNNOptions(
        *builder, /*time_major=*/true,
        /*fused_activation_function=*/activation,
        /*merge_outputs=*/op.merge_outputs);
  }

  // Restores merge_outputs; debug builds verify the activation is tanh.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    DCHECK(::tflite::ActivationFunctionType_TANH ==
           options.fused_activation_function());
    op->merge_outputs = options.merge_outputs();
  }

  // Returns one flag per input, with the two hidden-state inputs flagged.
  std::vector<bool> GetMutatingInputVariables(
      const Operator& op) const override {
    constexpr int kForwardHiddenState = 4;
    constexpr int kBackwardHiddenState = 8;
    std::vector<bool> is_mutating(op.inputs.size(), false);
    is_mutating[kForwardHiddenState] = true;
    is_mutating[kBackwardHiddenState] = true;
    return is_mutating;
  }
};
// Serializer for the mean reduction; the only option carried is keep_dims.
class Mean : public BuiltinOperator<MeanOperator, ::tflite::ReducerOptions,
                                    ::tflite::BuiltinOptions_ReducerOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Writes a ReducerOptions table holding the operator's keep_dims flag.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto keep_dims = op.keep_dims;
    return ::tflite::CreateReducerOptions(*builder, keep_dims);
  }

  // Copies keep_dims from the serialized options into the toco operator.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    const auto keep_dims = options.keep_dims();
    op->keep_dims = keep_dims;
  }
};
// Serializer for the sum reduction; the only option carried is keep_dims.
class Sum
    : public BuiltinOperator<TensorFlowSumOperator, ::tflite::ReducerOptions,
                             ::tflite::BuiltinOptions_ReducerOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Writes a ReducerOptions table holding the operator's keep_dims flag.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto keep_dims = op.keep_dims;
    return ::tflite::CreateReducerOptions(*builder, keep_dims);
  }

  // Copies keep_dims from the serialized options into the toco operator.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    const auto keep_dims = options.keep_dims();
    op->keep_dims = keep_dims;
  }
};
// Serializer for the max reduction; the only option carried is keep_dims.
class ReduceMax
    : public BuiltinOperator<TensorFlowMaxOperator, ::tflite::ReducerOptions,
                             ::tflite::BuiltinOptions_ReducerOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Writes a ReducerOptions table holding the operator's keep_dims flag.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto keep_dims = op.keep_dims;
    return ::tflite::CreateReducerOptions(*builder, keep_dims);
  }

  // Copies keep_dims from the serialized options into the toco operator.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    const auto keep_dims = options.keep_dims();
    op->keep_dims = keep_dims;
  }
};
// Serializer for the min reduction; the only option carried is keep_dims.
class ReduceMin
    : public BuiltinOperator<TensorFlowMinOperator, ::tflite::ReducerOptions,
                             ::tflite::BuiltinOptions_ReducerOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Writes a ReducerOptions table holding the operator's keep_dims flag.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto keep_dims = op.keep_dims;
    return ::tflite::CreateReducerOptions(*builder, keep_dims);
  }

  // Copies keep_dims from the serialized options into the toco operator.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    const auto keep_dims = options.keep_dims();
    op->keep_dims = keep_dims;
  }
};
// Serializer for the product reduction; the only option carried is keep_dims.
class ReduceProd
    : public BuiltinOperator<TensorFlowProdOperator, ::tflite::ReducerOptions,
                             ::tflite::BuiltinOptions_ReducerOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Writes a ReducerOptions table holding the operator's keep_dims flag.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto keep_dims = op.keep_dims;
    return ::tflite::CreateReducerOptions(*builder, keep_dims);
  }

  // Copies keep_dims from the serialized options into the toco operator.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    const auto keep_dims = options.keep_dims();
    op->keep_dims = keep_dims;
  }
};
// Serializer for the "any" reduction; the only option carried is keep_dims.
class ReduceAny
    : public BuiltinOperator<TensorFlowAnyOperator, ::tflite::ReducerOptions,
                             ::tflite::BuiltinOptions_ReducerOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Writes a ReducerOptions table holding the operator's keep_dims flag.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto keep_dims = op.keep_dims;
    return ::tflite::CreateReducerOptions(*builder, keep_dims);
  }

  // Copies keep_dims from the serialized options into the toco operator.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    const auto keep_dims = options.keep_dims();
    op->keep_dims = keep_dims;
  }
};
// Serializer for bilinear resize. Round-trips align_corners and
// half_pixel_centers, and feeds both into the op-version computation.
class ResizeBilinear
    : public BuiltinOperator<ResizeBilinearOperator,
                             ::tflite::ResizeBilinearOptions,
                             ::tflite::BuiltinOptions_ResizeBilinearOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Writes align_corners and half_pixel_centers into the option table.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto align_corners = op.align_corners;
    const auto half_pixel_centers = op.half_pixel_centers;
    return ::tflite::CreateResizeBilinearOptions(*builder, align_corners,
                                                 half_pixel_centers);
  }

  // Restores both flags from the serialized options.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    op->half_pixel_centers = options.half_pixel_centers();
    op->align_corners = options.align_corners();
  }

  // Builds the versioning signature, fills in the resize flags, and asks the
  // versioning library for the builtin operator version.
  int GetVersion(const OperatorSignature& op_signature) const override {
    ::tflite::OpSignature op_sig =
        GetVersioningOpSig(builtin_op(), op_signature);
    const auto& resize_op =
        static_cast<const ResizeBilinearOperator&>(*op_signature.op);
    auto& resize_options = op_sig.options.resize;
    resize_options.half_pixel_centers = resize_op.half_pixel_centers;
    resize_options.align_corners = resize_op.align_corners;
    return ::tflite::GetBuiltinOperatorVersion(op_sig);
  }
};
// Serializer for nearest-neighbor resize. Round-trips align_corners and
// half_pixel_centers, and feeds both into the op-version computation.
class ResizeNearestNeighbor
    : public BuiltinOperator<
          ResizeNearestNeighborOperator, ::tflite::ResizeNearestNeighborOptions,
          ::tflite::BuiltinOptions_ResizeNearestNeighborOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Writes align_corners and half_pixel_centers into the option table.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto align_corners = op.align_corners;
    const auto half_pixel_centers = op.half_pixel_centers;
    return ::tflite::CreateResizeNearestNeighborOptions(
        *builder, align_corners, half_pixel_centers);
  }

  // Restores both flags from the serialized options.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    op->half_pixel_centers = options.half_pixel_centers();
    op->align_corners = options.align_corners();
  }

  // Builds the versioning signature, fills in the resize flags, and asks the
  // versioning library for the builtin operator version.
  int GetVersion(const OperatorSignature& op_signature) const override {
    ::tflite::OpSignature op_sig =
        GetVersioningOpSig(builtin_op(), op_signature);
    const auto& resize_op =
        static_cast<const ResizeNearestNeighborOperator&>(*op_signature.op);
    auto& resize_options = op_sig.options.resize;
    resize_options.half_pixel_centers = resize_op.half_pixel_centers;
    resize_options.align_corners = resize_op.align_corners;
    return ::tflite::GetBuiltinOperatorVersion(op_sig);
  }
};
// Serializer for the squeeze builtin; the only option is the list of
// dimensions to squeeze.
class Squeeze
    : public BuiltinOperator<SqueezeOperator, ::tflite::SqueezeOptions,
                             ::tflite::BuiltinOptions_SqueezeOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Writes the operator's squeeze_dims as a flatbuffer vector.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    auto squeeze_dims = builder->CreateVector(op.squeeze_dims);
    return ::tflite::CreateSqueezeOptions(*builder, squeeze_dims);
  }

  // Appends the serialized squeeze_dims to the toco operator's list.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    // FlatBuffers vector accessors return nullptr when the field was not
    // written, so a missing squeeze_dims is treated as an empty list instead
    // of being dereferenced.
    const auto* squeeze_dims = options.squeeze_dims();
    if (squeeze_dims == nullptr) {
      return;
    }
    op->squeeze_dims.insert(op->squeeze_dims.end(), squeeze_dims->begin(),
                            squeeze_dims->end());
  }
};
// Serializer for the split builtin; the only option is the number of splits.
class Split
    : public BuiltinOperator<TensorFlowSplitOperator, ::tflite::SplitOptions,
                             ::tflite::BuiltinOptions_SplitOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Writes the operator's num_split into the option table.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto num_splits = op.num_split;
    return ::tflite::CreateSplitOptions(*builder, num_splits);
  }

  // Restores num_split from the serialized num_splits field.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    const auto num_splits = options.num_splits();
    op->num_split = num_splits;
  }
};
// Serializer for the split-v builtin; the only option is the number of
// splits.
class SplitV
    : public BuiltinOperator<TensorFlowSplitVOperator, ::tflite::SplitVOptions,
                             ::tflite::BuiltinOptions_SplitVOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Writes the operator's num_split into the option table.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto num_splits = op.num_split;
    return ::tflite::CreateSplitVOptions(*builder, num_splits);
  }

  // Restores num_split from the serialized num_splits field.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    const auto num_splits = options.num_splits();
    op->num_split = num_splits;
  }
};
// Serializer for the strided-slice builtin. Round-trips the five mask fields
// and reports the number of start indices to the versioning library.
class StridedSlice
    : public BuiltinOperator<StridedSliceOperator,
                             ::tflite::StridedSliceOptions,
                             ::tflite::BuiltinOptions_StridedSliceOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Writes begin/end/ellipsis/new_axis/shrink_axis masks, in that order.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto begin_mask = op.begin_mask;
    const auto end_mask = op.end_mask;
    const auto ellipsis_mask = op.ellipsis_mask;
    const auto new_axis_mask = op.new_axis_mask;
    const auto shrink_axis_mask = op.shrink_axis_mask;
    return ::tflite::CreateStridedSliceOptions(*builder, begin_mask, end_mask,
                                               ellipsis_mask, new_axis_mask,
                                               shrink_axis_mask);
  }

  // Restores all five masks from the serialized options.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    op->shrink_axis_mask = options.shrink_axis_mask();
    op->new_axis_mask = options.new_axis_mask();
    op->ellipsis_mask = options.ellipsis_mask();
    op->end_mask = options.end_mask();
    op->begin_mask = options.begin_mask();
  }

  // Builds the versioning signature, records the size of start_indices as
  // num_dims, and asks the versioning library for the operator version.
  int GetVersion(const OperatorSignature& op_signature) const override {
    ::tflite::OpSignature op_sig =
        GetVersioningOpSig(builtin_op(), op_signature);
    const auto& strided_slice_op =
        static_cast<const StridedSliceOperator&>(*op_signature.op);
    op_sig.options.single_input_op.num_dims =
        strided_slice_op.start_indices.size();
    return ::tflite::GetBuiltinOperatorVersion(op_sig);
  }
};
// Serializer for the top-k (v2) builtin. The option table has no fields that
// are written or read here.
class TopK_V2 : public BuiltinOperator<TopKV2Operator, ::tflite::TopKV2Options,
                                       ::tflite::BuiltinOptions_TopKV2Options> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Writes a default-valued option table.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto options_offset = ::tflite::CreateTopKV2Options(*builder);
    return options_offset;
  }

  // Intentionally empty: nothing is restored from the options.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {}
};
// Serializer for the arg-max builtin; the only option is the output data
// type.
class ArgMax : public BuiltinOperator<ArgMaxOperator, ::tflite::ArgMaxOptions,
                                      ::tflite::BuiltinOptions_ArgMaxOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Writes the operator's output_data_type, converted by DataType::Serialize.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto output_type = DataType::Serialize(op.output_data_type);
    return ::tflite::CreateArgMaxOptions(*builder, output_type);
  }

  // Restores output_data_type, converted by DataType::Deserialize.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    const auto serialized_type = options.output_type();
    op->output_data_type = DataType::Deserialize(serialized_type);
  }
};
// Serializer for the arg-min builtin; the only option is the output data
// type.
class ArgMin : public BuiltinOperator<ArgMinOperator, ::tflite::ArgMinOptions,
                                      ::tflite::BuiltinOptions_ArgMinOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Writes the operator's output_data_type, converted by DataType::Serialize.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto output_type = DataType::Serialize(op.output_data_type);
    return ::tflite::CreateArgMinOptions(*builder, output_type);
  }

  // Restores output_data_type, converted by DataType::Deserialize.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    const auto serialized_type = options.output_type();
    op->output_data_type = DataType::Deserialize(serialized_type);
  }
};
// Serializer for the transpose-convolution builtin. Round-trips the padding
// type and the two strides.
class TransposeConv
    : public BuiltinOperator<TransposeConvOperator,
                             ::tflite::TransposeConvOptions,
                             ::tflite::BuiltinOptions_TransposeConvOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Writes padding, stride_width and stride_height, in that order.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    return ::tflite::CreateTransposeConvOptions(
        *builder, Padding::Serialize(op.padding.type), op.stride_width,
        op.stride_height);
  }

  // Restores the strides and padding type from the serialized options.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    op->stride_height = options.stride_h();
    op->stride_width = options.stride_w();
    op->padding.type = Padding::Deserialize(options.padding());
  }
};
// Serializer for the sparse-to-dense builtin; the only option is
// validate_indices.
class SparseToDense
    : public BuiltinOperator<SparseToDenseOperator,
                             ::tflite::SparseToDenseOptions,
                             ::tflite::BuiltinOptions_SparseToDenseOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Writes the operator's validate_indices flag into the option table.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto validate_indices = op.validate_indices;
    return ::tflite::CreateSparseToDenseOptions(*builder, validate_indices);
  }

  // Restores validate_indices from the serialized options.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    const auto validate_indices = options.validate_indices();
    op->validate_indices = validate_indices;
  }
};
// Serializer for the expand-dims builtin. The option table has no fields
// that are written or read here.
class ExpandDims
    : public BuiltinOperator<ExpandDimsOperator, ::tflite::ExpandDimsOptions,
                             ::tflite::BuiltinOptions_ExpandDimsOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Writes a default-valued option table.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto options_offset = ::tflite::CreateExpandDimsOptions(*builder);
    return options_offset;
  }

  // Intentionally empty: nothing is restored from the options.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {}
};
// Serializer for the pack builtin. Round-trips values_count and axis.
class Pack : public BuiltinOperator<PackOperator, ::tflite::PackOptions,
                                    ::tflite::BuiltinOptions_PackOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Writes values_count followed by axis into the option table.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto values_count = op.values_count;
    const auto axis = op.axis;
    return ::tflite::CreatePackOptions(*builder, values_count, axis);
  }

  // Restores axis and values_count from the serialized options.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    op->axis = options.axis();
    op->values_count = options.values_count();
  }
};
// Serializer for the shape builtin; the only option is the output data type.
class Shape
    : public BuiltinOperator<TensorFlowShapeOperator, ::tflite::ShapeOptions,
                             ::tflite::BuiltinOptions_ShapeOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Writes the operator's output_data_type, converted by DataType::Serialize.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto out_type = DataType::Serialize(op.output_data_type);
    return ::tflite::CreateShapeOptions(*builder, out_type);
  }

  // Restores output_data_type from the serialized out_type field.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    const auto serialized_type = options.out_type();
    op->output_data_type = DataType::Deserialize(serialized_type);
  }
};
// Serializer for the one-hot builtin; the only option is the axis.
class OneHot : public BuiltinOperator<OneHotOperator, ::tflite::OneHotOptions,
                                      ::tflite::BuiltinOptions_OneHotOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Writes the operator's axis into the option table.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto axis = op.axis;
    return ::tflite::CreateOneHotOptions(*builder, axis);
  }

  // Restores axis from the serialized options.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    const auto axis = options.axis();
    op->axis = axis;
  }
};
// Custom-op serializer for the CTC beam search decoder. Options are stored
// under the keys "beam_width", "top_paths" and "merge_repeated".
class CTCBeamSearchDecoder
    : public CustomOperator<CTCBeamSearchDecoderOperator> {
 public:
  using CustomOperator::CustomOperator;

  // Writes the three decoder settings as keyed flexbuffer values.
  void WriteOptions(const TocoOperator& op,
                    flexbuffers::Builder* fbb) const override {
    const auto beam_width = op.beam_width;
    const auto top_paths = op.top_paths;
    const auto merge_repeated = op.merge_repeated;
    fbb->Int("beam_width", beam_width);
    fbb->Int("top_paths", top_paths);
    fbb->Bool("merge_repeated", merge_repeated);
  }

  // Restores the three decoder settings from the flexbuffer map.
  void ReadOptions(const flexbuffers::Map& m, TocoOperator* op) const override {
    const auto beam_width = m["beam_width"].AsInt32();
    op->beam_width = beam_width;
    const auto top_paths = m["top_paths"].AsInt32();
    op->top_paths = top_paths;
    const auto merge_repeated = m["merge_repeated"].AsBool();
    op->merge_repeated = merge_repeated;
  }

  // This custom op always reports version 1.
  int GetVersion(const OperatorSignature& op_signature) const override {
    constexpr int kVersion = 1;
    return kVersion;
  }
};
// Serializer for the unpack builtin. Round-trips num and axis, and derives
// the op version from the data type of the first input.
class Unpack : public BuiltinOperator<UnpackOperator, ::tflite::UnpackOptions,
                                      ::tflite::BuiltinOptions_UnpackOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Writes num followed by axis into the option table.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto num = op.num;
    const auto axis = op.axis;
    return ::tflite::CreateUnpackOptions(*builder, num, axis);
  }

  // Restores axis and num from the serialized options.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    op->axis = options.axis();
    op->num = options.num();
  }

  // Returns 2 for int8/uint8 input, 3 for bool input, and 1 otherwise.
  int GetVersion(const OperatorSignature& op_signature) const override {
    const std::string& first_input = op_signature.op->inputs[0];
    const Array& first_array = op_signature.model->GetArray(first_input);
    switch (first_array.data_type) {
      case ArrayDataType::kInt8:
      case ArrayDataType::kUint8:
        return 2;
      case ArrayDataType::kBool:
        return 3;
      default:
        return 1;
    }
  }
};
// Serializer for the leaky-ReLU builtin; the only option is alpha.
class LeakyRelu
    : public BuiltinOperator<LeakyReluOperator, ::tflite::LeakyReluOptions,
                             ::tflite::BuiltinOptions_LeakyReluOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Writes the operator's alpha into the option table.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto alpha = op.alpha;
    return ::tflite::CreateLeakyReluOptions(*builder, alpha);
  }

  // Restores alpha from the serialized options.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    const auto alpha = options.alpha();
    op->alpha = alpha;
  }
};
// Serializer for the squared-difference builtin. The option table has no
// fields that are written or read here.
class SquaredDifference
    : public BuiltinOperator<
          SquaredDifferenceOperator, ::tflite::SquaredDifferenceOptions,
          ::tflite::BuiltinOptions_SquaredDifferenceOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Writes a default-valued option table.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto options_offset =
        ::tflite::CreateSquaredDifferenceOptions(*builder);
    return options_offset;
  }

  // Intentionally empty: nothing is restored from the options.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {}
};
// Serializer for the mirror-pad builtin. Translates the toco padding mode to
// and from the TFLite enum; anything other than "reflect" maps to
// "symmetric" in both directions.
class MirrorPad
    : public BuiltinOperator<MirrorPadOperator, ::tflite::MirrorPadOptions,
                             ::tflite::BuiltinOptions_MirrorPadOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Writes REFLECT when the operator mode is kReflect, SYMMETRIC otherwise.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    auto tflite_mode = ::tflite::MirrorPadMode::MirrorPadMode_SYMMETRIC;
    if (op.mode == MirrorPadMode::kReflect) {
      tflite_mode = ::tflite::MirrorPadMode::MirrorPadMode_REFLECT;
    }
    return ::tflite::CreateMirrorPadOptions(*builder, tflite_mode);
  }

  // Sets kReflect when the serialized mode is REFLECT, kSymmetric otherwise.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    if (options.mode() == ::tflite::MirrorPadMode::MirrorPadMode_REFLECT) {
      op->mode = MirrorPadMode::kReflect;
    } else {
      op->mode = MirrorPadMode::kSymmetric;
    }
  }
};
// Serializer for the unique builtin. The only option is the index output
// type, which is stored as INT64 or, for anything else, INT32.
class Unique : public BuiltinOperator<UniqueOperator, ::tflite::UniqueOptions,
                                      ::tflite::BuiltinOptions_UniqueOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Writes INT64 when idx_out_type is kInt64, INT32 otherwise.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const UniqueOperator& unique_op = static_cast<const UniqueOperator&>(op);
    ::tflite::TensorType idx_out_type = ::tflite::TensorType_INT32;
    if (unique_op.idx_out_type == toco::ArrayDataType::kInt64) {
      idx_out_type = ::tflite::TensorType::TensorType_INT64;
    }
    return ::tflite::CreateUniqueOptions(*builder, idx_out_type);
  }

  // Sets idx_out_type to kInt64 when the serialized type is INT64, kInt32
  // otherwise.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    UniqueOperator* unique_op = static_cast<UniqueOperator*>(op);
    if (options.idx_out_type() == ::tflite::TensorType_INT64) {
      unique_op->idx_out_type = toco::ArrayDataType::kInt64;
    } else {
      unique_op->idx_out_type = toco::ArrayDataType::kInt32;
    }
  }
};
// Serializer for the unidirectional sequence RNN builtin. The options are
// always written as time-major with tanh activation; nothing is copied back
// into the toco operator on import.
class UnidirectionalSequenceRnn
    : public BuiltinOperator<UnidirectionalSequenceRnnOperator,
                             ::tflite::SequenceRNNOptions,
                             ::tflite::BuiltinOptions_SequenceRNNOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Writes the fixed option table; `op` carries no configurable fields here.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto activation = ::tflite::ActivationFunctionType_TANH;
    return ::tflite::CreateSequenceRNNOptions(
        *builder, /*time_major=*/true,
        /*fused_activation_function=*/activation);
  }

  // Nothing is read back; debug builds verify the activation is tanh.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    DCHECK(::tflite::ActivationFunctionType_TANH ==
           options.fused_activation_function());
  }

  // Returns one flag per input; only input 4 is flagged as mutating.
  std::vector<bool> GetMutatingInputVariables(
      const Operator& op) const override {
    constexpr int kMutatingInput = 4;
    std::vector<bool> is_mutating(op.inputs.size(), false);
    is_mutating[kMutatingInput] = true;
    return is_mutating;
  }
};
// Serializer for the where builtin. The option table has no fields that are
// written or read here.
class Where : public BuiltinOperator<WhereOperator, ::tflite::WhereOptions,
                                     ::tflite::BuiltinOptions_WhereOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Writes a default-valued option table.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto options_offset = ::tflite::CreateWhereOptions(*builder);
    return options_offset;
  }

  // Intentionally empty: nothing is restored from the options.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {}
};
std::unique_ptr<flexbuffers::Builder> WriteFlexOpOptions(
const std::string& tensorflow_node_def) {
auto fbb = absl::make_unique<flexbuffers::Builder>();
::tensorflow::NodeDef node_def;
if (!node_def.ParseFromString(tensorflow_node_def)) {
LOG(ERROR) << "Failed to parse TensorFlow NodeDef";
return {};
}
fbb->Vector([&]() {
fbb->String(node_def.op());
fbb->String(tensorflow_node_def);
});
fbb->Finish();
LOG(INFO) << "Writing flex op: " << node_def.op();
return std::unique_ptr<flexbuffers::Builder>(fbb.release());
}
// Serializer for TensorFlow ops that have no builtin TFLite equivalent.
//
// When `enable_select_tf_ops` is set, the op is exported as a flex op (op
// name plus serialized NodeDef). Otherwise the NodeDef's attributes of
// supported types are written into a flexbuffer map as custom options.
class TensorFlowUnsupported : public BaseOperator {
 public:
  TensorFlowUnsupported(const std::string& name, OperatorType type,
                        bool enable_select_tf_ops)
      : BaseOperator(name, type), enable_select_tf_ops_(enable_select_tf_ops) {}

  // Serializes `op` (which must be a TensorFlowUnsupportedOperator) into
  // custom options. Emits an empty custom-options offset when WriteOptions
  // produces no builder.
  Options Serialize(const Operator& op,
                    flatbuffers::FlatBufferBuilder* builder) const override {
    auto fbb =
        WriteOptions(static_cast<const TensorFlowUnsupportedOperator&>(op));
    if (fbb) {
      return Options::Custom(builder->CreateVector(fbb->GetBuffer()));
    } else {
      return Options::Custom(0);
    }
  }

  // Creates a TensorFlowUnsupportedOperator and, if custom options are
  // present, fills it by interpreting them as a flexbuffer map.
  std::unique_ptr<Operator> Deserialize(
      const BuiltinOptions* builtin_options,
      const CustomOptions* custom_options) const override {
    // Deserializing Flex ops doesn't work now.
    // TODO(ycling): Revisit and decide if we should fix the flow for importing
    // TFLite models with Flex ops.
    auto op = absl::make_unique<TensorFlowUnsupportedOperator>();
    if (custom_options) {
      auto flexbuffer_map =
          flexbuffers::GetRoot(custom_options->data(), custom_options->size())
              .AsMap();
      ReadOptions(flexbuffer_map, op.get());
    }
    return std::unique_ptr<Operator>(op.release());
  }

  // Builds the flexbuffer payload for `op`.
  //
  // Returns null when the stored NodeDef cannot be parsed, or when none of
  // its attributes has a supported type. Supported attribute types are
  // string, int, float, bool, and non-empty lists of string, int or float;
  // everything else is skipped with a warning.
  std::unique_ptr<flexbuffers::Builder> WriteOptions(
      const TensorFlowUnsupportedOperator& op) const {
    if (enable_select_tf_ops_) {
      return WriteFlexOpOptions(op.tensorflow_node_def);
    }
    auto fbb = absl::make_unique<flexbuffers::Builder>();

    ::tensorflow::NodeDef node_def;
    if (!node_def.ParseFromString(op.tensorflow_node_def)) {
      LOG(ERROR) << "Failed to parse TensorFlow NodeDef";
      return std::unique_ptr<flexbuffers::Builder>();
    }

    if (ShouldExportAsFlexOp(enable_select_tf_ops_, node_def.op())) {
      fbb->Vector([&]() {
        fbb->String(node_def.op());
        fbb->String(op.tensorflow_node_def);
      });
      fbb->Finish();
      LOG(INFO) << "Writing flex op: " << node_def.op();
      return std::unique_ptr<flexbuffers::Builder>(fbb.release());
    }

    bool has_valid_attr = false;
    size_t map_start = fbb->StartMap();
    for (const auto& pair : node_def.attr()) {
      const char* key = pair.first.c_str();
      const auto& attr = pair.second;
      switch (attr.value_case()) {
        case ::tensorflow::AttrValue::kS:
          fbb->String(key, attr.s());
          has_valid_attr = true;
          break;
        case ::tensorflow::AttrValue::kI:
          fbb->Int(key, attr.i());
          has_valid_attr = true;
          break;
        case ::tensorflow::AttrValue::kF:
          fbb->Float(key, attr.f());
          has_valid_attr = true;
          break;
        case ::tensorflow::AttrValue::kB:
          fbb->Bool(key, attr.b());
          has_valid_attr = true;
          break;
        case ::tensorflow::AttrValue::kList:
          // Only the first non-empty element kind of a list is exported.
          if (attr.list().s_size() > 0) {
            auto start = fbb->StartVector(key);
            for (const std::string& v : attr.list().s()) {
              fbb->Add(v);
            }
            fbb->EndVector(start, /*typed=*/true, /*fixed=*/false);
            has_valid_attr = true;
          } else if (attr.list().i_size() > 0) {
            auto start = fbb->StartVector(key);
            for (const int64_t v : attr.list().i()) {
              fbb->Add(v);
            }
            fbb->EndVector(start, /*typed=*/true, /*fixed=*/false);
            has_valid_attr = true;
          } else if (attr.list().f_size() > 0) {
            auto start = fbb->StartVector(key);
            for (const float v : attr.list().f()) {
              fbb->Add(v);
            }
            fbb->EndVector(start, /*typed=*/true, /*fixed=*/false);
            has_valid_attr = true;
          } else {
            LOG(WARNING)
                << "Ignoring unsupported type in list attribute with key '"
                << key << "'";
          }
          break;
        default:
          LOG(WARNING) << "Ignoring unsupported attribute type with key '"
                       << key << "'";
          break;
      }
    }
    if (!has_valid_attr) {
      return std::unique_ptr<flexbuffers::Builder>();
    }
    fbb->EndMap(map_start);
    fbb->Finish();
    return std::unique_ptr<flexbuffers::Builder>(fbb.release());
  }

  // Rebuilds a NodeDef from the flexbuffer map `m` and stores its serialized
  // form in `op->tensorflow_node_def`. The boolean keys "_output_quantized"
  // and "_support_output_type_float_in_quantized_op" are additionally
  // mirrored into the corresponding fields of `op`. Values of unsupported
  // flexbuffer types are skipped with a warning.
  void ReadOptions(const flexbuffers::Map& m,
                   TensorFlowUnsupportedOperator* op) const {
    ::tensorflow::NodeDef node_def;
    auto attr = node_def.mutable_attr();

    const auto& keys = m.Keys();
    for (size_t i = 0; i < keys.size(); ++i) {
      const auto key = keys[i].AsKey();
      const auto& value = m[key];
      switch (value.GetType()) {
        case flexbuffers::FBT_STRING:
          // Use str() rather than c_str(): WriteOptions stores the full
          // attribute string, and c_str() would cut it at the first embedded
          // NUL byte. This matches the string-list branch below.
          (*attr)[key].set_s(value.AsString().str());
          break;
        case flexbuffers::FBT_INT:
          (*attr)[key].set_i(value.AsInt64());
          break;
        case flexbuffers::FBT_FLOAT:
          (*attr)[key].set_f(value.AsFloat());
          break;
        case flexbuffers::FBT_BOOL: {
          (*attr)[key].set_b(value.AsBool());
          const std::string key_name(key);
          if (key_name == "_output_quantized") {
            op->quantized = value.AsBool();
          }
          if (key_name == "_support_output_type_float_in_quantized_op") {
            op->support_output_type_float_in_quantized_op = value.AsBool();
          }
          break;
        }
        case flexbuffers::FBT_VECTOR_INT: {
          auto* list = (*attr)[key].mutable_list();
          const auto& vector = value.AsTypedVector();
          for (size_t j = 0; j < vector.size(); j++) {
            list->add_i(vector[j].AsInt64());
          }
          break;
        }
        case flexbuffers::FBT_VECTOR_FLOAT: {
          auto* list = (*attr)[key].mutable_list();
          const auto& vector = value.AsTypedVector();
          for (size_t j = 0; j < vector.size(); j++) {
            list->add_f(vector[j].AsFloat());
          }
          break;
        }
        case 15 /* TO_DO(wvo): flexbuffers::FBT_VECTOR_STRING_DEPRECATED*/: {
          auto* list = (*attr)[key].mutable_list();
          const auto& vector = value.AsTypedVector();
          for (size_t j = 0; j < vector.size(); j++) {
            list->add_s(vector[j].AsString().str());
          }
          break;
        }
        default:
          LOG(WARNING) << "Ignoring unsupported attribute type with key '"
                       << key << "'";
          break;
      }
    }
    node_def.SerializeToString(&op->tensorflow_node_def);
  }

  // Always reports version 1.
  int GetVersion(const OperatorSignature& op_signature) const override {
    // TODO(ycling): Design and implement a way to plumb the version of
    // custom ops.
    return 1;
  }

 private:
  // When true, WriteOptions exports the op as a flex op.
  const bool enable_select_tf_ops_;
};
// Serializer for the dequantize builtin. The option table has no fields that
// are written or read here.
class Dequantize
    : public BuiltinOperator<DequantizeOperator, ::tflite::DequantizeOptions,
                             ::tflite::BuiltinOptions_DequantizeOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Writes a default-valued option table.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto options_offset = ::tflite::CreateDequantizeOptions(*builder);
    return options_offset;
  }

  // Intentionally empty: nothing is restored from the options.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {}
};
// Serializer for the reverse-sequence builtin. Round-trips seq_dim and
// batch_dim.
class ReverseSequence
    : public BuiltinOperator<ReverseSequenceOperator,
                             ::tflite::ReverseSequenceOptions,
                             ::tflite::BuiltinOptions_ReverseSequenceOptions> {
 public:
  using BuiltinOperator::BuiltinOperator;

  // Writes seq_dim followed by batch_dim into the option table.
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    const auto seq_dim = op.seq_dim;
    const auto batch_dim = op.batch_dim;
    return ::tflite::CreateReverseSequenceOptions(*builder, seq_dim,
                                                  batch_dim);
  }

  // Restores batch_dim and seq_dim from the serialized options.
  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    op->batch_dim = options.batch_dim();
    op->seq_dim = options.seq_dim();
  }
};
namespace {
// Build a vector containing all the known operators.
std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList(
bool enable_select_tf_ops = false) {
std::vector<std::unique_ptr<BaseOperator>> ops;
using tensorflow::MakeUnique;
// Builtin Operators.
ops.push_back(
MakeUnique<Add>(::tflite::BuiltinOperator_ADD, OperatorType::kAdd));
ops.push_back(
MakeUnique<AddN>(::tflite::BuiltinOperator_ADD_N, OperatorType::kAddN));
ops.push_back(
MakeUnique<Div>(::tflite::BuiltinOperator_DIV, OperatorType::kDiv));
ops.push_back(
MakeUnique<Sub>(::tflite::BuiltinOperator_SUB, OperatorType::kSub));
ops.push_back(MakeUnique<AveragePool>(
::tflite::BuiltinOperator_AVERAGE_POOL_2D, OperatorType::kAveragePool));
ops.push_back(
MakeUnique<SpaceToBatchND>(::tflite::BuiltinOperator_SPACE_TO_BATCH_ND,
OperatorType::kSpaceToBatchND));
ops.push_back(
MakeUnique<BatchToSpaceND>(::tflite::BuiltinOperator_BATCH_TO_SPACE_ND,
OperatorType::kBatchToSpaceND));
ops.push_back(MakeUnique<Concatenation>(
::tflite::BuiltinOperator_CONCATENATION, OperatorType::kConcatenation));
ops.push_back(MakeUnique<Convolution>(::tflite::BuiltinOperator_CONV_2D,
OperatorType::kConv));
ops.push_back(MakeUnique<DepthwiseConvolution>(
::tflite::BuiltinOperator_DEPTHWISE_CONV_2D,
OperatorType::kDepthwiseConv));
ops.push_back(MakeUnique<Dequantize>(::tflite::BuiltinOperator_DEQUANTIZE,
OperatorType::kDequantize));
ops.push_back(
MakeUnique<FullyConnected>(::tflite::BuiltinOperator_FULLY_CONNECTED,
OperatorType::kFullyConnected));
ops.push_back(MakeUnique<Gather>(::tflite::BuiltinOperator_GATHER,
OperatorType::kGather));
ops.push_back(MakeUnique<GatherNd>(::tflite::BuiltinOperator_GATHER_ND,
OperatorType::kGatherNd));
ops.push_back(
MakeUnique<L2Normalization>(::tflite::BuiltinOperator_L2_NORMALIZATION,
OperatorType::kL2Normalization));
ops.push_back(MakeUnique<L2Pool>(::tflite::BuiltinOperator_L2_POOL_2D,
OperatorType::kL2Pool));
ops.push_back(MakeUnique<LocalResponseNormalization>(
::tflite::BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION,
OperatorType::kLocalResponseNormalization));
ops.push_back(MakeUnique<MaxPool>(::tflite::BuiltinOperator_MAX_POOL_2D,
OperatorType::kMaxPool));
ops.push_back(
MakeUnique<Mul>(::tflite::BuiltinOperator_MUL, OperatorType::kMul));
ops.push_back(
MakeUnique<Pad>(::tflite::BuiltinOperator_PAD, OperatorType::kPad));
ops.push_back(
MakeUnique<PadV2>(::tflite::BuiltinOperator_PADV2, OperatorType::kPadV2));
ops.push_back(MakeUnique<Reshape>(::tflite::BuiltinOperator_RESHAPE,
OperatorType::kReshape));
ops.push_back(MakeUnique<Softmax>(::tflite::BuiltinOperator_SOFTMAX,
OperatorType::kSoftmax));
ops.push_back(MakeUnique<SpaceToDepth>(
::tflite::BuiltinOperator_SPACE_TO_DEPTH, OperatorType::kSpaceToDepth));
ops.push_back(MakeUnique<DepthToSpace>(
::tflite::BuiltinOperator_DEPTH_TO_SPACE, OperatorType::kDepthToSpace));
ops.push_back(
MakeUnique<Svdf>(::tflite::BuiltinOperator_SVDF, OperatorType::kSvdf));
ops.push_back(MakeUnique<Transpose>(::tflite::BuiltinOperator_TRANSPOSE,
OperatorType::kTranspose));
ops.push_back(
MakeUnique<Mean>(::tflite::BuiltinOperator_MEAN, OperatorType::kMean));
ops.push_back(
MakeUnique<Sum>(::tflite::BuiltinOperator_SUM, OperatorType::kSum));
ops.push_back(MakeUnique<ReduceProd>(::tflite::BuiltinOperator_REDUCE_PROD,
OperatorType::kReduceProd));
ops.push_back(MakeUnique<ReduceMax>(::tflite::BuiltinOperator_REDUCE_MAX,
OperatorType::kReduceMax));
ops.push_back(MakeUnique<ReduceMin>(::tflite::BuiltinOperator_REDUCE_MIN,
OperatorType::kReduceMin));
ops.push_back(MakeUnique<ReduceAny>(::tflite::BuiltinOperator_REDUCE_ANY,
OperatorType::kAny));
ops.push_back(
MakeUnique<ResizeBilinear>(::tflite::BuiltinOperator_RESIZE_BILINEAR,
OperatorType::kResizeBilinear));
ops.push_back(MakeUnique<ResizeNearestNeighbor>(
::tflite::BuiltinOperator_RESIZE_NEAREST_NEIGHBOR,
OperatorType::kResizeNearestNeighbor));
ops.push_back(MakeUnique<Squeeze>(::tflite::BuiltinOperator_SQUEEZE,
OperatorType::kSqueeze));
ops.push_back(
MakeUnique<Split>(::tflite::BuiltinOperator_SPLIT, OperatorType::kSplit));
ops.push_back(MakeUnique<SplitV>(::tflite::BuiltinOperator_SPLIT_V,
OperatorType::kSplitV));
ops.push_back(MakeUnique<StridedSlice>(
::tflite::BuiltinOperator_STRIDED_SLICE, OperatorType::kStridedSlice));
ops.push_back(MakeUnique<TopK_V2>(::tflite::BuiltinOperator_TOPK_V2,
OperatorType::kTopK_V2));
ops.push_back(MakeUnique<Lstm>(::tflite::BuiltinOperator_LSTM,
OperatorType::kLstmCell));
ops.push_back(
MakeUnique<Cast>(::tflite::BuiltinOperator_CAST, OperatorType::kCast));
ops.push_back(MakeUnique<ArgMax>(::tflite::BuiltinOperator_ARG_MAX,
OperatorType::kArgMax));
ops.push_back(MakeUnique<ArgMin>(::tflite::BuiltinOperator_ARG_MIN,
OperatorType::kArgMin));
ops.push_back(
MakeUnique<Tile>(::tflite::BuiltinOperator_TILE, OperatorType::kTile));
ops.push_back(MakeUnique<ExpandDims>(::tflite::BuiltinOperator_EXPAND_DIMS,
OperatorType::kExpandDims));
ops.push_back(MakeUnique<TransposeConv>(
::tflite::BuiltinOperator_TRANSPOSE_CONV, OperatorType::kTransposeConv));
ops.push_back(MakeUnique<SparseToDense>(
::tflite::BuiltinOperator_SPARSE_TO_DENSE, OperatorType::kSparseToDense));
ops.push_back(
MakeUnique<Shape>(::tflite::BuiltinOperator_SHAPE, OperatorType::kShape));
ops.push_back(MakeUnique<FakeQuant>(::tflite::BuiltinOperator_FAKE_QUANT,
OperatorType::kFakeQuant));
ops.push_back(
MakeUnique<Pack>(::tflite::BuiltinOperator_PACK, OperatorType::kPack));
ops.emplace_back(MakeUnique<UnidirectionalSequenceLstm>(
::tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM,
OperatorType::kUnidirectionalSequenceLstm));
ops.emplace_back(MakeUnique<BidirectionalSequenceLstm>(
::tflite::BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM,
OperatorType::kBidirectionalSequenceLstm));
ops.emplace_back(MakeUnique<BidirectionalSequenceRnn>(
::tflite::BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN,
OperatorType::kBidirectionalSequenceRnn));
ops.push_back(MakeUnique<OneHot>(::tflite::BuiltinOperator_ONE_HOT,
OperatorType::kOneHot));
ops.push_back(MakeUnique<Unpack>(::tflite::BuiltinOperator_UNPACK,
OperatorType::kUnpack));
ops.push_back(MakeUnique<LeakyRelu>(::tflite::BuiltinOperator_LEAKY_RELU,
OperatorType::kLeakyRelu));
ops.push_back(MakeUnique<SquaredDifference>(
::tflite::BuiltinOperator_SQUARED_DIFFERENCE,
OperatorType::kSquaredDifference));
ops.push_back(MakeUnique<MirrorPad>(::tflite::BuiltinOperator_MIRROR_PAD,
OperatorType::kMirrorPad));
ops.push_back(MakeUnique<Unique>(::tflite::BuiltinOperator_UNIQUE,
OperatorType::kUnique));
ops.push_back(MakeUnique<UnidirectionalSequenceRnn>(
::tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN,
OperatorType::kUnidirectionalSequenceRnn));
ops.push_back(
MakeUnique<Where>(::tflite::BuiltinOperator_WHERE, OperatorType::kWhere));
ops.push_back(
MakeUnique<ReverseSequence>(::tflite::BuiltinOperator_REVERSE_SEQUENCE,
OperatorType::kReverseSequence));
ops.push_back(MakeUnique<SimpleOperator<MatrixDiagOperator>>(
::tflite::BuiltinOperator_MATRIX_DIAG, OperatorType::kMatrixDiag));
ops.push_back(MakeUnique<SimpleOperator<MatrixSetDiagOperator>>(
::tflite::BuiltinOperator_MATRIX_SET_DIAG, OperatorType::kMatrixSetDiag));
// Custom Operators.
ops.push_back(MakeUnique<CTCBeamSearchDecoder>(
"CTC_BEAM_SEARCH_DECODER", OperatorType::kCTCBeamSearchDecoder));
ops.push_back(MakeUnique<TensorFlowUnsupported>("TENSORFLOW_UNSUPPORTED",
OperatorType::kUnsupported,
enable_select_tf_ops));
// SimpleOperator was designed to export CUSTOM TF Lite ops, but has since
// been modified to also export builtins. As TOCO evolved we added warnings
// when custom ops are exported but SimpleOperator bypasses those. To
// prevent user confusion we are settling on using SimpleOperator only for
// builtins.
ops.push_back(MakeUnique<SimpleOperator<FloorOperator>>(
::tflite::BuiltinOperator_FLOOR, OperatorType::kFloor));
ops.push_back(MakeUnique<SimpleOperator<CeilOperator>>(
::tflite::BuiltinOperator_CEIL, OperatorType::kCeil));
ops.push_back(MakeUnique<SimpleOperator<EluOperator>>(
::tflite::BuiltinOperator_ELU, OperatorType::kElu));
ops.push_back(MakeUnique<SimpleOperator<RoundOperator>>(
::tflite::BuiltinOperator_ROUND, OperatorType::kRound));
ops.push_back(MakeUnique<SimpleOperator<ReluOperator>>(
::tflite::BuiltinOperator_RELU, OperatorType::kRelu));
ops.push_back(MakeUnique<SimpleOperator<Relu1Operator>>(
::tflite::BuiltinOperator_RELU_N1_TO_1, OperatorType::kRelu1));
ops.push_back(MakeUnique<SimpleOperator<Relu6Operator>>(
::tflite::BuiltinOperator_RELU6, OperatorType::kRelu6));
ops.push_back(MakeUnique<SimpleOperator<PReluOperator>>(
::tflite::BuiltinOperator_PRELU, OperatorType::kPRelu));
ops.push_back(MakeUnique<SimpleOperator<LogisticOperator>>(
::tflite::BuiltinOperator_LOGISTIC, OperatorType::kLogistic));
ops.push_back(MakeUnique<SimpleOperator<TanhOperator>>(
::tflite::BuiltinOperator_TANH, OperatorType::kTanh));
ops.push_back(MakeUnique<SimpleOperator<ExpOperator>>(
::tflite::BuiltinOperator_EXP, OperatorType::kExp));
ops.push_back(MakeUnique<SimpleOperator<CosOperator>>(
::tflite::BuiltinOperator_COS, OperatorType::kCos));
ops.push_back(MakeUnique<SimpleOperator<LogSoftmaxOperator>>(
::tflite::BuiltinOperator_LOG_SOFTMAX, OperatorType::kLogSoftmax));
ops.push_back(MakeUnique<SimpleOperator<TensorFlowMaximumOperator>>(
::tflite::BuiltinOperator_MAXIMUM, OperatorType::kMaximum));
ops.push_back(MakeUnique<SimpleOperator<TensorFlowMinimumOperator>>(
::tflite::BuiltinOperator_MINIMUM, OperatorType::kMinimum));
ops.push_back(MakeUnique<SimpleOperator<TensorFlowGreaterOperator>>(
::tflite::BuiltinOperator_GREATER, OperatorType::kGreater));
ops.push_back(MakeUnique<SimpleOperator<TensorFlowGreaterEqualOperator>>(
::tflite::BuiltinOperator_GREATER_EQUAL, OperatorType::kGreaterEqual));
ops.push_back(MakeUnique<SimpleOperator<TensorFlowLessOperator>>(
::tflite::BuiltinOperator_LESS, OperatorType::kLess));
ops.push_back(MakeUnique<SimpleOperator<TensorFlowLessEqualOperator>>(
::tflite::BuiltinOperator_LESS_EQUAL, OperatorType::kLessEqual));
ops.push_back(MakeUnique<SimpleOperator<TensorFlowEqualOperator>>(
::tflite::BuiltinOperator_EQUAL, OperatorType::kEqual));
ops.push_back(MakeUnique<SimpleOperator<TensorFlowNotEqualOperator>>(
::tflite::BuiltinOperator_NOT_EQUAL, OperatorType::kNotEqual));
ops.push_back(MakeUnique<SimpleOperator<NegOperator>>(
::tflite::BuiltinOperator_NEG, OperatorType::kNeg));
ops.push_back(MakeUnique<SimpleOperator<SelectOperator>>(
::tflite::BuiltinOperator_SELECT, OperatorType::kSelect));
ops.push_back(MakeUnique<SimpleOperator<SliceOperator>>(
::tflite::BuiltinOperator_SLICE, OperatorType::kSlice));
ops.push_back(MakeUnique<SimpleOperator<PowOperator>>(
::tflite::BuiltinOperator_POW, OperatorType::kPow));
ops.push_back(MakeUnique<SimpleOperator<LogicalOrOperator>>(
::tflite::BuiltinOperator_LOGICAL_OR, OperatorType::kLogicalOr));
ops.emplace_back(new SimpleOperator<LogicalAndOperator>(
::tflite::BuiltinOperator_LOGICAL_AND, OperatorType::kLogicalAnd));
ops.emplace_back(new SimpleOperator<LogicalNotOperator>(
::tflite::BuiltinOperator_LOGICAL_NOT, OperatorType::kLogicalNot));
ops.emplace_back(new SimpleOperator<FloorDivOperator>(
::tflite::BuiltinOperator_FLOOR_DIV, OperatorType::kFloorDiv));
ops.emplace_back(new SimpleOperator<FloorModOperator>(
::tflite::BuiltinOperator_FLOOR_MOD, OperatorType::kFloorMod));
ops.emplace_back(new SimpleOperator<RangeOperator>(
::tflite::BuiltinOperator_RANGE, OperatorType::kRange));
// Element-wise operator
ops.push_back(MakeUnique<SimpleOperator<SinOperator>>(
::tflite::BuiltinOperator_SIN, OperatorType::kSin));
ops.push_back(MakeUnique<SimpleOperator<LogOperator>>(
::tflite::BuiltinOperator_LOG, OperatorType::kLog));
ops.push_back(MakeUnique<SimpleOperator<TensorFlowSqrtOperator>>(
::tflite::BuiltinOperator_SQRT, OperatorType::kSqrt));
ops.push_back(MakeUnique<SimpleOperator<TensorFlowRsqrtOperator>>(
::tflite::BuiltinOperator_RSQRT, OperatorType::kRsqrt));
ops.push_back(MakeUnique<SimpleOperator<TensorFlowSquareOperator>>(
::tflite::BuiltinOperator_SQUARE, OperatorType::kSquare));
ops.push_back(MakeUnique<SimpleOperator<TensorFlowZerosLikeOperator>>(
::tflite::BuiltinOperator_ZEROS_LIKE, OperatorType::kZerosLike));
ops.push_back(MakeUnique<SimpleOperator<AbsOperator>>(
::tflite::BuiltinOperator_ABS, OperatorType::kAbs));
ops.push_back(MakeUnique<SimpleOperator<HardSwishOperator>>(
::tflite::BuiltinOperator_HARD_SWISH, OperatorType::kHardSwish));
ops.push_back(MakeUnique<SimpleOperator<FillOperator>>(
::tflite::BuiltinOperator_FILL, OperatorType::kFill));
ops.push_back(MakeUnique<SimpleOperator<ReverseV2Operator>>(
::tflite::BuiltinOperator_REVERSE_V2, OperatorType::kReverseV2));
ops.push_back(MakeUnique<SimpleOperator<TensorFlowRankOperator>>(
::tflite::BuiltinOperator_RANK, OperatorType::kRank));
ops.emplace_back(new SimpleOperator<SegmentSumOperator>(
::tflite::BuiltinOperator_SEGMENT_SUM, OperatorType::kSegmentSum));
ops.emplace_back(MakeUnique<SimpleOperator<ScatterNdOperator>>(
::tflite::BuiltinOperator_SCATTER_ND, OperatorType::kScatterNd));
return ops;
}
} // namespace
// LINT.ThenChange(//tensorflow/lite/tools/versioning/op_version.cc)
std::map<OperatorType, std::unique_ptr<BaseOperator>> BuildOperatorByTypeMap(
bool enable_select_tf_ops) {
std::map<OperatorType, std::unique_ptr<BaseOperator>> result;
std::vector<std::unique_ptr<BaseOperator>> ops =
BuildOperatorList(enable_select_tf_ops);
for (auto& op : ops) {
result[op->type()] = std::move(op);
}
return result;
}
std::map<std::string, std::unique_ptr<BaseOperator>> BuildOperatorByNameMap(
bool enable_select_tf_ops) {
std::map<std::string, std::unique_ptr<BaseOperator>> result;
std::vector<std::unique_ptr<BaseOperator>> ops =
BuildOperatorList(enable_select_tf_ops);
for (auto& op : ops) {
result[op->name()] = std::move(op);
}
return result;
}
// Decides whether the TensorFlow op named `tensorflow_op_name` should be
// exported as a Flex op.
//
// Returns true only when all of the following hold:
//   * `enable_select_tf_ops` is set,
//   * the global TensorFlow op registry has an `OpDef` for the op, and
//   * the op is in the whitelisted Flex op set.
// When this returns false, the op is meant to be exported as a regular custom
// op instead. A warning is logged for ops that are registered with TensorFlow
// but are not whitelisted.
bool ShouldExportAsFlexOp(bool enable_select_tf_ops,
                          const std::string& tensorflow_op_name) {
  // Flex ops are disabled altogether: nothing else to check.
  if (!enable_select_tf_ops) return false;

  // The op must be known to TensorFlow. `op_def` is only needed as the
  // out-parameter of the lookup; the status of the lookup is what matters.
  const tensorflow::OpDef* op_def = nullptr;
  const bool is_registered = tensorflow::OpRegistry::Global()
                                 ->LookUpOpDef(tensorflow_op_name, &op_def)
                                 .ok();
  if (!is_registered) return false;

  // A registered op still has to be part of the whitelisted Flex op set.
  const bool is_whitelisted =
      ::tflite::flex::IsWhitelistedFlexOp(tensorflow_op_name);
  if (!is_whitelisted) {
    LOG(WARNING) << "Op " << tensorflow_op_name
                 << " is a valid TensorFlow op but has not been whitelisted for"
                    " the TensorFlow Lite flex op set.";
  }
  return is_whitelisted;
}
} // namespace tflite
} // namespace toco