// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Revision History
// Version 0: Initial version.
// Version 1: Add subgraphs to schema.
// Version 2: Rename operators to conform to NN API.
// Version 3: Move buffer data from Model.Subgraph.Tensors to Model.Buffers.
namespace tflite;
// This file identifier corresponds to schema version 3.
file_identifier "TFL3";
// File extension of any written files.
file_extension "tflite";
// IMPORTANT: All new members of tables, enums and unions must be added at the
// end to ensure backwards compatibility.
// The type of data stored in a tensor.
enum TensorType : byte {
FLOAT32 = 0,
FLOAT16 = 1,
INT32 = 2,
UINT8 = 3,
INT64 = 4,
STRING = 5,
BOOL = 6,
INT16 = 7,
COMPLEX64 = 8,
INT8 = 9,
}
// Custom quantization parameters for experimenting with new quantization
// techniques.
table CustomQuantization {
custom:[byte];
}
// Represents a specific quantization technique's parameters.
union QuantizationDetails {
CustomQuantization,
}
// Parameters for converting a quantized tensor back to float.
table QuantizationParameters {
// These four parameters are the asymmetric linear quantization parameters.
// Given a quantized value q, the corresponding float value f should be:
// f = scale * (q - zero_point)
// For other quantization types, the QuantizationDetails below is used.
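// For example, with scale = 0.5 and zero_point = 128, the quantized value
// q = 130 dequantizes to f = 0.5 * (130 - 128) = 1.0.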
min:[float]; // For importing back into TensorFlow.
max:[float]; // For importing back into TensorFlow.
scale:[float]; // For dequantizing the tensor's values.
zero_point:[long];
// If this is not NONE, the quantization parameters above are ignored and the
// value of the QuantizationDetails union below is used instead.
details:QuantizationDetails;
}
table Tensor {
// The tensor shape. The meaning of each entry is operator-specific but
// builtin ops use: [batch size, height, width, number of channels] (That's
// TensorFlow's NHWC).
shape:[int];
type:TensorType;
// An index that refers to the buffers table at the root of the model. Or,
// if there is no data buffer associated (i.e. intermediate results), then
// this is 0 (which refers to an always-present empty buffer).
//
// The data_buffer itself is an opaque container, with the assumption that the
// target device is little-endian. In addition, all builtin operators assume
// the memory is ordered such that if `shape` is [4, 3, 2], then index
// [i, j, k] maps to data_buffer[i*3*2 + j*2 + k].
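// For example, with shape [4, 3, 2] the element at index [1, 2, 1] is stored
// at data_buffer[1*3*2 + 2*2 + 1] = data_buffer[11].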
buffer:uint;
name:string; // For debugging and importing back into TensorFlow.
quantization:QuantizationParameters; // Optional.
is_variable:bool = false;
}
// A list of builtin operators. Builtin operators are slightly faster than
// custom ones, though not by much. Moreover, while custom operators accept an
// opaque object containing configuration parameters, builtins have a
// predetermined set of acceptable options.
enum BuiltinOperator : byte {
ADD = 0,
AVERAGE_POOL_2D = 1,
CONCATENATION = 2,
CONV_2D = 3,
DEPTHWISE_CONV_2D = 4,
// DEPTH_TO_SPACE = 5,
DEQUANTIZE = 6,
EMBEDDING_LOOKUP = 7,
FLOOR = 8,
FULLY_CONNECTED = 9,
HASHTABLE_LOOKUP = 10,
L2_NORMALIZATION = 11,
L2_POOL_2D = 12,
LOCAL_RESPONSE_NORMALIZATION = 13,
LOGISTIC = 14,
LSH_PROJECTION = 15,
LSTM = 16,
MAX_POOL_2D = 17,
MUL = 18,
RELU = 19,
// NOTE(aselle): RELU_N1_TO_1 used to be called RELU1, but it was renamed
// since different model developers use RELU1 in different ways. Never
// create another op called RELU1.
RELU_N1_TO_1 = 20,
RELU6 = 21,
RESHAPE = 22,
RESIZE_BILINEAR = 23,
RNN = 24,
SOFTMAX = 25,
SPACE_TO_DEPTH = 26,
SVDF = 27,
TANH = 28,
// TODO(aselle): Consider rename to CONCATENATE_EMBEDDINGS
CONCAT_EMBEDDINGS = 29,
SKIP_GRAM = 30,
CALL = 31,
CUSTOM = 32,
EMBEDDING_LOOKUP_SPARSE = 33,
PAD = 34,
UNIDIRECTIONAL_SEQUENCE_RNN = 35,
GATHER = 36,
BATCH_TO_SPACE_ND = 37,
SPACE_TO_BATCH_ND = 38,
TRANSPOSE = 39,
MEAN = 40,
SUB = 41,
DIV = 42,
SQUEEZE = 43,
UNIDIRECTIONAL_SEQUENCE_LSTM = 44,
STRIDED_SLICE = 45,
BIDIRECTIONAL_SEQUENCE_RNN = 46,
EXP = 47,
TOPK_V2 = 48,
SPLIT = 49,
LOG_SOFTMAX = 50,
// DELEGATE is a special op type for the operations which are delegated to
// other backends.
// WARNING: Experimental interface, subject to change
DELEGATE = 51,
BIDIRECTIONAL_SEQUENCE_LSTM = 52,
CAST = 53,
PRELU = 54,
MAXIMUM = 55,
ARG_MAX = 56,
MINIMUM = 57,
LESS = 58,
NEG = 59,
PADV2 = 60,
GREATER = 61,
GREATER_EQUAL = 62,
LESS_EQUAL = 63,
SELECT = 64,
SLICE = 65,
SIN = 66,
TRANSPOSE_CONV = 67,
SPARSE_TO_DENSE = 68,
TILE = 69,
EXPAND_DIMS = 70,
EQUAL = 71,
NOT_EQUAL = 72,
LOG = 73,
SUM = 74,
SQRT = 75,
RSQRT = 76,
SHAPE = 77,
POW = 78,
ARG_MIN = 79,
FAKE_QUANT = 80,
REDUCE_PROD = 81,
REDUCE_MAX = 82,
PACK = 83,
LOGICAL_OR = 84,
ONE_HOT = 85,
LOGICAL_AND = 86,
LOGICAL_NOT = 87,
UNPACK = 88,
REDUCE_MIN = 89,
FLOOR_DIV = 90,
REDUCE_ANY = 91,
SQUARE = 92,
ZEROS_LIKE = 93,
FILL = 94,
FLOOR_MOD = 95,
RANGE = 96,
RESIZE_NEAREST_NEIGHBOR = 97,
}
// Options for the builtin operators.
union BuiltinOptions {
Conv2DOptions,
DepthwiseConv2DOptions,
ConcatEmbeddingsOptions,
LSHProjectionOptions,
Pool2DOptions,
SVDFOptions,
RNNOptions,
FullyConnectedOptions,
SoftmaxOptions,
ConcatenationOptions,
AddOptions,
L2NormOptions,
LocalResponseNormalizationOptions,
LSTMOptions,
ResizeBilinearOptions,
CallOptions,
ReshapeOptions,
SkipGramOptions,
SpaceToDepthOptions,
EmbeddingLookupSparseOptions,
MulOptions,
PadOptions,
GatherOptions,
BatchToSpaceNDOptions,
SpaceToBatchNDOptions,
TransposeOptions,
ReducerOptions,
SubOptions,
DivOptions,
SqueezeOptions,
SequenceRNNOptions,
StridedSliceOptions,
ExpOptions,
TopKV2Options,
SplitOptions,
LogSoftmaxOptions,
CastOptions,
DequantizeOptions,
MaximumMinimumOptions,
ArgMaxOptions,
LessOptions,
NegOptions,
PadV2Options,
GreaterOptions,
GreaterEqualOptions,
LessEqualOptions,
SelectOptions,
SliceOptions,
TransposeConvOptions,
SparseToDenseOptions,
TileOptions,
ExpandDimsOptions,
EqualOptions,
NotEqualOptions,
ShapeOptions,
PowOptions,
ArgMinOptions,
FakeQuantOptions,
PackOptions,
LogicalOrOptions,
OneHotOptions,
LogicalAndOptions,
LogicalNotOptions,
UnpackOptions,
FloorDivOptions,
SquareOptions,
ZerosLikeOptions,
FillOptions,
BidirectionalSequenceLSTMOptions,
BidirectionalSequenceRNNOptions,
UnidirectionalSequenceLSTMOptions,
FloorModOptions,
RangeOptions,
ResizeNearestNeighborOptions,
}
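// Padding scheme, matching TensorFlow's convolution padding: SAME pads the
// input so that each spatial output dimension is ceil(input_size / stride),
// while VALID adds no padding and only emits windows that fit entirely inside
// the input.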
enum Padding : byte { SAME, VALID }
enum ActivationFunctionType : byte {
NONE = 0,
RELU = 1,
RELU_N1_TO_1 = 2,
RELU6 = 3,
TANH = 4,
SIGN_BIT = 5,
}
table Conv2DOptions {
padding:Padding;
stride_w:int;
stride_h:int;
fused_activation_function:ActivationFunctionType;
dilation_w_factor:int = 1;
dilation_h_factor:int = 1;
}
table Pool2DOptions {
padding:Padding;
stride_w:int;
stride_h:int;
filter_width:int;
filter_height:int;
fused_activation_function:ActivationFunctionType;
}
table DepthwiseConv2DOptions {
// Parameters for DepthwiseConv version 1 or above.
padding:Padding;
stride_w:int;
stride_h:int;
depth_multiplier:int;
fused_activation_function:ActivationFunctionType;
// Parameters for DepthwiseConv version 2 or above.
dilation_w_factor:int = 1;
dilation_h_factor:int = 1;
}
table ConcatEmbeddingsOptions {
num_channels:int;
num_columns_per_channel:[int];
embedding_dim_per_channel:[int]; // This could be inferred from parameters.
}
enum LSHProjectionType: byte {
UNKNOWN = 0,
SPARSE = 1,
DENSE = 2,
}
table LSHProjectionOptions {
type: LSHProjectionType;
}
table SVDFOptions {
rank:int;
fused_activation_function:ActivationFunctionType;
}
// An implementation of TensorFlow RNNCell.
table RNNOptions {
fused_activation_function:ActivationFunctionType;
}
// An implementation of TensorFlow dynamic_rnn with RNNCell.
table SequenceRNNOptions {
time_major:bool;
fused_activation_function:ActivationFunctionType;
}
// An implementation of TensorFlow bidirectional_dynamic_rnn with RNNCell.
table BidirectionalSequenceRNNOptions {
time_major:bool;
fused_activation_function:ActivationFunctionType;
merge_outputs: bool;
}
enum FullyConnectedOptionsWeightsFormat: byte {
DEFAULT = 0,
SHUFFLED4x16INT8 = 1,
}
// An implementation of TensorFlow fully_connected (a.k.a. Dense) layer.
table FullyConnectedOptions {
// Parameters for FullyConnected version 1 or above.
fused_activation_function:ActivationFunctionType;
// Parameters for FullyConnected version 2 or above.
weights_format:FullyConnectedOptionsWeightsFormat = DEFAULT;
}
table SoftmaxOptions {
beta: float;
}
// An implementation of TensorFlow concat.
table ConcatenationOptions {
axis:int;
fused_activation_function:ActivationFunctionType;
}
table AddOptions {
fused_activation_function:ActivationFunctionType;
}
table MulOptions {
fused_activation_function:ActivationFunctionType;
}
table L2NormOptions {
fused_activation_function:ActivationFunctionType;
}
table LocalResponseNormalizationOptions {
radius:int;
bias:float;
alpha:float;
beta:float;
}
enum LSTMKernelType : byte {
// Full LSTM kernel which supports peephole and projection.
FULL = 0,
// Basic LSTM kernels. Equivalent to TensorFlow BasicLSTMCell.
BASIC = 1,
}
// An implementation of TensorFlow LSTMCell and CoupledInputForgetGateLSTMCell.
table LSTMOptions {
// Parameters for LSTM version 1 or above.
fused_activation_function:ActivationFunctionType;
cell_clip: float; // Optional, 0.0 means no clipping
proj_clip: float; // Optional, 0.0 means no clipping
// Parameters for LSTM version 2 or above.
// Basic kernel is only supported in version 2 or above.
kernel_type: LSTMKernelType = FULL;
}
// An implementation of TensorFlow dynamic_rnn with LSTMCell.
table UnidirectionalSequenceLSTMOptions {
fused_activation_function:ActivationFunctionType;
cell_clip: float; // Optional, 0.0 means no clipping
proj_clip: float; // Optional, 0.0 means no clipping
// If true, the first dimension is the sequence; otherwise it is the batch.
time_major:bool;
}
table BidirectionalSequenceLSTMOptions {
fused_activation_function:ActivationFunctionType;
cell_clip: float; // Optional, 0.0 means no clipping
proj_clip: float; // Optional, 0.0 means no clipping
// If true, store the outputs of both directions into the first output.
merge_outputs: bool;
}
table ResizeBilinearOptions {
new_height: int (deprecated);
new_width: int (deprecated);
align_corners: bool;
}
table ResizeNearestNeighborOptions {
align_corners: bool;
}
// Options for a call operation.
table CallOptions {
// The subgraph index that needs to be called.
subgraph:uint;
}
table PadOptions {
}
table PadV2Options {
}
table ReshapeOptions {
new_shape:[int];
}
table SpaceToBatchNDOptions {
}
table BatchToSpaceNDOptions {
}
table SkipGramOptions {
ngram_size: int;
max_skip_size: int;
include_all_ngrams: bool;
}
table SpaceToDepthOptions {
block_size: int;
}
table SubOptions {
fused_activation_function:ActivationFunctionType;
}
table DivOptions {
fused_activation_function:ActivationFunctionType;
}
table TopKV2Options {
}
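// How the embedding vectors gathered by EMBEDDING_LOOKUP_SPARSE are combined,
// mirroring TensorFlow's embedding_lookup_sparse combiners: SUM is the
// weighted sum, MEAN divides that sum by the total weight, and SQRTN divides
// it by the square root of the sum of the squared weights.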
enum CombinerType : byte {
SUM = 0,
MEAN = 1,
SQRTN = 2,
}
table EmbeddingLookupSparseOptions {
combiner:CombinerType;
}
table GatherOptions {
axis: int;
}
table TransposeOptions {
}
table ExpOptions {
}
table ReducerOptions {
keep_dims: bool;
}
table SqueezeOptions {
squeeze_dims:[int];
}
table SplitOptions {
num_splits: int;
}
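// The mask fields below follow TensorFlow's strided_slice semantics: if bit i
// of begin_mask (resp. end_mask) is set, begin[i] (resp. end[i]) is ignored
// and the widest possible range is used; a bit set in ellipsis_mask marks an
// ellipsis position; new_axis_mask inserts a dimension of size 1; and
// shrink_axis_mask collapses the corresponding dimension.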
table StridedSliceOptions {
begin_mask: int;
end_mask: int;
ellipsis_mask: int;
new_axis_mask: int;
shrink_axis_mask: int;
}
table LogSoftmaxOptions {
}
table CastOptions {
in_data_type: TensorType;
out_data_type: TensorType;
}
table DequantizeOptions {
}
table MaximumMinimumOptions {
}
table TileOptions {
}
table ArgMaxOptions {
output_type : TensorType;
}
table ArgMinOptions {
output_type : TensorType;
}
table GreaterOptions {
}
table GreaterEqualOptions {
}
table LessOptions {
}
table LessEqualOptions {
}
table NegOptions {
}
table SelectOptions {
}
table SliceOptions {
}
table TransposeConvOptions {
padding:Padding;
stride_w:int;
stride_h:int;
}
table ExpandDimsOptions {
}
table SparseToDenseOptions {
validate_indices:bool;
}
table EqualOptions {
}
table NotEqualOptions {
}
table ShapeOptions {
// Optional output type of the operation (int32 or int64). Defaults to int32.
out_type : TensorType;
}
table PowOptions {
}
table FakeQuantOptions {
// Parameters supported by version 1:
min:float;
max:float;
num_bits:int;
// Parameters supported by version 2:
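// If true, the quantized range is narrowed by one step (e.g. [1, 255] instead
// of [0, 255] for 8 bits), matching TensorFlow's fake_quant narrow_range.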
narrow_range:bool;
}
table PackOptions {
values_count:int;
axis:int;
}
table LogicalOrOptions {
}
table OneHotOptions {
axis:int;
}
table LogicalAndOptions {
}
table LogicalNotOptions {
}
table UnpackOptions {
num:int;
axis:int;
}
table FloorDivOptions {
}
table SquareOptions {
}
table ZerosLikeOptions {
}
table FillOptions {
}
table FloorModOptions {
}
table RangeOptions {
}
// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
// builtin, or a string if the operator is custom.
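// When builtin_code is CUSTOM, custom_code carries the name under which the
// custom operator implementation is registered at runtime.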
table OperatorCode {
builtin_code:BuiltinOperator;
custom_code:string;
// The version of the operator. The version needs to be bumped whenever new
// parameters are introduced into an op.
version:int = 1;
}
enum CustomOptionsFormat : byte {
FLEXBUFFERS = 0,
}
// An operator takes tensors as inputs and outputs. The type of operation being
// performed is determined by an index into the list of valid OperatorCodes,
// while the specifics of each operation are configured using builtin_options
// or custom_options.
table Operator {
// Index into the operator_codes array. Using an integer here avoids
// complicated map lookups.
opcode_index:uint;
// Optional input and output tensors are indicated by -1.
inputs:[int];
outputs:[int];
builtin_options:BuiltinOptions;
custom_options:[ubyte];
custom_options_format:CustomOptionsFormat;
// A list of booleans indicating the input tensors which are being mutated by
// this operator (e.g. used by RNN and LSTM).
// For example, if the "inputs" array refers to 5 tensors and the second and
// fifth are mutable variables, then this list will contain
// [false, true, false, false, true].
//
// If the list is empty, no variable is mutated in this operator.
// The list either has the same length as `inputs`, or is empty.
mutating_variable_inputs:[bool];
}
// Defines a subgraph, which typically represents an entire model. (The file's
// actual root type is Model, defined below.)
table SubGraph {
// A list of all tensors used in this subgraph.
tensors:[Tensor];
// Indices of the tensors that are inputs into this subgraph. Note this is
// the list of non-static tensors that feed into the subgraph for inference.
inputs:[int];
// Indices of the tensors that are outputs out of this subgraph. Note this is
// the list of output tensors that are considered the product of the
// subgraph's inference.
outputs:[int];
// All operators, in execution order.
operators:[Operator];
// Name of this subgraph (used for debugging).
name:string;
}
// Table of raw data buffers (used for constant tensors). Referenced by tensors
// by index. The generous alignment accommodates mmap-friendly data structures.
table Buffer {
data:[ubyte] (force_align: 16);
}
table Model {
// Version of the schema.
version:uint;
// A list of all operator codes used in this model. This is
// kept in order because operators carry an index into this
// vector.
operator_codes:[OperatorCode];
// All the subgraphs of the model. The 0th is assumed to be the main
// model.
subgraphs:[SubGraph];
// A description of the model.
description:string;
// Buffers of the model.
// Note the 0th entry of this array must be an empty buffer (sentinel).
// This is a convention so that tensors without a buffer can provide 0 as
// their buffer.
buffers:[Buffer];
// Metadata about the model. Indirects into the existing buffers list.
metadata_buffer:[int];
}
root_type Model;
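// A minimal sketch of walking a serialized model with the C++ accessors that
// FlatBuffers generates from this schema; buffer_data is assumed to point at
// a valid .tflite file already loaded into memory:
//
//   const tflite::Model* model = tflite::GetModel(buffer_data);
//   const tflite::SubGraph* graph = model->subgraphs()->Get(0);
//   const tflite::Tensor* input =
//       graph->tensors()->Get(graph->inputs()->Get(0));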