Implement TFLite Quantize operation.
PiperOrigin-RevId: 239312773
This commit is contained in:
parent
02f54c9a54
commit
6ee35631af
@ -139,6 +139,7 @@ typedef enum {
|
||||
kTfLiteBuiltinElu = 111,
|
||||
kTfLiteBuiltinReverseSequence = 112,
|
||||
kTfLiteBuiltinMatrixDiag = 113,
|
||||
kTfLiteBuiltinQuantize = 114,
|
||||
} TfLiteBuiltinOperator;
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
@ -744,6 +744,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
|
||||
case BuiltinOperator_GATHER_ND:
|
||||
case BuiltinOperator_WHERE:
|
||||
case BuiltinOperator_RANK:
|
||||
case BuiltinOperator_QUANTIZE:
|
||||
break;
|
||||
}
|
||||
return kTfLiteOk;
|
||||
|
@ -199,6 +199,7 @@ cc_library(
|
||||
"pad.cc",
|
||||
"pooling.cc",
|
||||
"pow.cc",
|
||||
"quantize.cc",
|
||||
"range.cc",
|
||||
"rank.cc",
|
||||
"reduce.cc",
|
||||
@ -240,6 +241,7 @@ cc_library(
|
||||
deps = [
|
||||
":activation_functor",
|
||||
":eigen_support",
|
||||
":gemm_support",
|
||||
":kernel_util",
|
||||
":lstm_eval",
|
||||
":op_macros",
|
||||
@ -247,7 +249,6 @@ cc_library(
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite:string_util",
|
||||
"//tensorflow/lite/c:c_api_internal",
|
||||
"//tensorflow/lite/kernels:gemm_support",
|
||||
"//tensorflow/lite/kernels/internal:audio_utils",
|
||||
"//tensorflow/lite/kernels/internal:kernel_utils",
|
||||
"//tensorflow/lite/kernels/internal:optimized",
|
||||
@ -1401,3 +1402,16 @@ cc_test(
|
||||
"@com_google_googletest//:gtest",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "quantize_test",
|
||||
size = "small",
|
||||
srcs = ["quantize_test.cc"],
|
||||
deps = [
|
||||
":builtin_ops",
|
||||
":test_util",
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite/kernels/internal:types",
|
||||
"@com_google_googletest//:gtest",
|
||||
],
|
||||
)
|
||||
|
@ -47,6 +47,7 @@ namespace tflite {
|
||||
namespace optimized_ops {
|
||||
|
||||
// Unoptimized reference ops:
|
||||
using reference_ops::AffineQuantize;
|
||||
using reference_ops::ArgMax;
|
||||
using reference_ops::ArgMinMax;
|
||||
using reference_ops::Broadcast4DSlowGreater;
|
||||
|
@ -2878,6 +2878,26 @@ inline void Dequantize(const tflite::DequantizationParams& op_params,
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline void AffineQuantize(const tflite::QuantizationParams& op_params,
|
||||
const RuntimeShape& input_shape,
|
||||
const float* input_data,
|
||||
const RuntimeShape& output_shape, T* output_data) {
|
||||
gemmlowp::ScopedProfilingLabel label("Quantize");
|
||||
const int32 zero_point = op_params.zero_point;
|
||||
const double scale = static_cast<double>(op_params.scale);
|
||||
const int flat_size = MatchingFlatSize(input_shape, output_shape);
|
||||
static constexpr int32 min_val = std::numeric_limits<T>::min();
|
||||
static constexpr int32 max_val = std::numeric_limits<T>::max();
|
||||
|
||||
for (int i = 0; i < flat_size; i++) {
|
||||
const float val = input_data[i];
|
||||
int32 unclamped = static_cast<int32>(TfLiteRound(val / scale)) + zero_point;
|
||||
int32 clamped = std::min(std::max(unclamped, min_val), max_val);
|
||||
output_data[i] = clamped;
|
||||
}
|
||||
}
|
||||
|
||||
inline void FakeQuant(const tflite::FakeQuantParams& op_params,
|
||||
const RuntimeShape& input_shape, const float* input_data,
|
||||
const RuntimeShape& output_shape, float* output_data) {
|
||||
|
100
tensorflow/lite/kernels/quantize.cc
Normal file
100
tensorflow/lite/kernels/quantize.cc
Normal file
@ -0,0 +1,100 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
#include "tensorflow/lite/c/c_api_internal.h"
|
||||
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/kernels/op_macros.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace ops {
|
||||
namespace builtin {
|
||||
namespace quantize {
|
||||
|
||||
struct OpContext {
|
||||
OpContext(TfLiteContext* context, TfLiteNode* node) {
|
||||
input = GetInput(context, node, 0);
|
||||
output = GetOutput(context, node, 0);
|
||||
}
|
||||
const TfLiteTensor* input;
|
||||
TfLiteTensor* output;
|
||||
};
|
||||
|
||||
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
|
||||
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
|
||||
|
||||
OpContext op_context(context, node);
|
||||
|
||||
TF_LITE_ENSURE(context, op_context.input->type == kTfLiteFloat32);
|
||||
TF_LITE_ENSURE(context, op_context.output->type == kTfLiteUInt8 ||
|
||||
op_context.output->type == kTfLiteInt8);
|
||||
|
||||
// TODO(b/128934713): Add support for fixed-point per-channel quantization.
|
||||
// Currently this only support affine per-layer quantization.
|
||||
TF_LITE_ENSURE_EQ(context, op_context.output->quantization.type,
|
||||
kTfLiteAffineQuantization);
|
||||
const auto* affine_quantization = reinterpret_cast<TfLiteAffineQuantization*>(
|
||||
op_context.output->quantization.params);
|
||||
TF_LITE_ENSURE(context, affine_quantization);
|
||||
TF_LITE_ENSURE(context, affine_quantization->scale);
|
||||
TF_LITE_ENSURE(context, affine_quantization->scale->size == 1);
|
||||
|
||||
return context->ResizeTensor(context, op_context.output,
|
||||
TfLiteIntArrayCopy(op_context.input->dims));
|
||||
}
|
||||
|
||||
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
OpContext op_context(context, node);
|
||||
|
||||
tflite::QuantizationParams op_params;
|
||||
op_params.zero_point = op_context.output->params.zero_point;
|
||||
op_params.scale = op_context.output->params.scale;
|
||||
switch (op_context.output->type) {
|
||||
case kTfLiteUInt8:
|
||||
optimized_ops::AffineQuantize(op_params, GetTensorShape(op_context.input),
|
||||
GetTensorData<float>(op_context.input),
|
||||
GetTensorShape(op_context.output),
|
||||
GetTensorData<uint8_t>(op_context.output));
|
||||
break;
|
||||
case kTfLiteInt8:
|
||||
optimized_ops::AffineQuantize(op_params, GetTensorShape(op_context.input),
|
||||
GetTensorData<float>(op_context.input),
|
||||
GetTensorShape(op_context.output),
|
||||
GetTensorData<int8_t>(op_context.output));
|
||||
break;
|
||||
default:
|
||||
context->ReportError(context, "Type %d not supported.",
|
||||
op_context.input->type);
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
} // namespace quantize
|
||||
|
||||
TfLiteRegistration* Register_QUANTIZE_OPT() {
|
||||
static TfLiteRegistration r = {nullptr, nullptr, quantize::Prepare,
|
||||
quantize::Eval};
|
||||
return &r;
|
||||
}
|
||||
|
||||
TfLiteRegistration* Register_QUANTIZE() { return Register_QUANTIZE_OPT(); }
|
||||
|
||||
} // namespace builtin
|
||||
} // namespace ops
|
||||
} // namespace tflite
|
85
tensorflow/lite/kernels/quantize_test.cc
Normal file
85
tensorflow/lite/kernels/quantize_test.cc
Normal file
@ -0,0 +1,85 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include <cstdint>
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
#include "tensorflow/lite/kernels/internal/types.h"
|
||||
#include "tensorflow/lite/kernels/register.h"
|
||||
#include "tensorflow/lite/kernels/test_util.h"
|
||||
#include "tensorflow/lite/model.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace {
|
||||
|
||||
using ::testing::ElementsAreArray;
|
||||
|
||||
class QuantizeOpModel : public SingleOpModel {
|
||||
public:
|
||||
QuantizeOpModel(TensorType type, std::initializer_list<int> shape,
|
||||
float scale, int32_t zero_point) {
|
||||
const TensorData output_tensor_data = {type, shape, 0,
|
||||
0, scale, zero_point};
|
||||
input_ = AddInput({TensorType_FLOAT32, shape});
|
||||
output_ = AddOutput(output_tensor_data);
|
||||
SetBuiltinOp(BuiltinOperator_QUANTIZE, BuiltinOptions_QuantizeOptions,
|
||||
CreateQuantizeOptions(builder_).Union());
|
||||
|
||||
BuildInterpreter({GetShape(input_)});
|
||||
}
|
||||
|
||||
void SetInput(std::initializer_list<float> data) {
|
||||
PopulateTensor(input_, data);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::vector<T> GetOutput() {
|
||||
return ExtractVector<T>(output_);
|
||||
}
|
||||
|
||||
private:
|
||||
int input_;
|
||||
int output_;
|
||||
};
|
||||
|
||||
TEST(QuantizeOpTest, UINT8) {
|
||||
// [-63.5, 64] -> scale=0.5 zero_point=127 for UINT8
|
||||
QuantizeOpModel m(TensorType_UINT8, {2, 5}, 0.5, 127);
|
||||
|
||||
m.SetInput({-63.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 64});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetOutput<uint8_t>(),
|
||||
ElementsAreArray({0, 1, 2, 3, 4, 251, 252, 253, 254, 255}));
|
||||
}
|
||||
|
||||
TEST(QuantizeOpTest, INT8) {
|
||||
// [-63.5, 64] -> scale=0.5, zero_point=1 for INT8
|
||||
QuantizeOpModel m(TensorType_INT8, {2, 5}, 0.5, -1);
|
||||
|
||||
m.SetInput({-63.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 64});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetOutput<int8_t>(),
|
||||
ElementsAreArray(
|
||||
{-128, -127, -126, -125, -124, 123, 124, 125, 126, 127}));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tflite
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
::tflite::LogToStderr();
|
||||
::testing::InitGoogleTest(&argc, argv);
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
@ -140,6 +140,7 @@ TfLiteRegistration* Register_WHERE();
|
||||
TfLiteRegistration* Register_ELU();
|
||||
TfLiteRegistration* Register_REVERSE_SEQUENCE();
|
||||
TfLiteRegistration* Register_MATRIX_DIAG();
|
||||
TfLiteRegistration* Register_QUANTIZE();
|
||||
|
||||
TfLiteStatus UnsupportedTensorFlowOp(TfLiteContext* context, TfLiteNode* node) {
|
||||
context->ReportError(
|
||||
@ -376,6 +377,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
|
||||
AddBuiltin(BuiltinOperator_ELU, Register_ELU());
|
||||
AddBuiltin(BuiltinOperator_REVERSE_SEQUENCE, Register_REVERSE_SEQUENCE());
|
||||
AddBuiltin(BuiltinOperator_MATRIX_DIAG, Register_MATRIX_DIAG());
|
||||
AddBuiltin(BuiltinOperator_QUANTIZE, Register_QUANTIZE());
|
||||
|
||||
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
|
||||
// custom ops aren't always included by default.
|
||||
|
@ -672,6 +672,7 @@ TfLiteStatus AddOpsAndParams(
|
||||
case tflite::BuiltinOperator_ELU:
|
||||
case tflite::BuiltinOperator_REVERSE_SEQUENCE:
|
||||
case tflite::BuiltinOperator_MATRIX_DIAG:
|
||||
case tflite::BuiltinOperator_QUANTIZE:
|
||||
logError("Op code %d is currently not delegated to NNAPI", builtin);
|
||||
return kTfLiteError;
|
||||
break;
|
||||
|
@ -227,6 +227,7 @@ enum BuiltinOperator : byte {
|
||||
ELU = 111,
|
||||
REVERSE_SEQUENCE = 112,
|
||||
MATRIX_DIAG = 113,
|
||||
QUANTIZE = 114,
|
||||
}
|
||||
|
||||
// Options for the builtin operators.
|
||||
@ -319,6 +320,7 @@ union BuiltinOptions {
|
||||
RankOptions,
|
||||
ReverseSequenceOptions,
|
||||
MatrixDiagOptions,
|
||||
QuantizeOptions,
|
||||
}
|
||||
|
||||
enum Padding : byte { SAME, VALID }
|
||||
@ -762,6 +764,9 @@ table ReverseSequenceOptions {
|
||||
table MatrixDiagOptions {
|
||||
}
|
||||
|
||||
table QuantizeOptions {
|
||||
}
|
||||
|
||||
// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
|
||||
// builtin, or a string if the operator is custom.
|
||||
table OperatorCode {
|
||||
|
@ -295,6 +295,9 @@ struct ReverseSequenceOptionsT;
|
||||
struct MatrixDiagOptions;
|
||||
struct MatrixDiagOptionsT;
|
||||
|
||||
struct QuantizeOptions;
|
||||
struct QuantizeOptionsT;
|
||||
|
||||
struct OperatorCode;
|
||||
struct OperatorCodeT;
|
||||
|
||||
@ -558,11 +561,12 @@ enum BuiltinOperator {
|
||||
BuiltinOperator_ELU = 111,
|
||||
BuiltinOperator_REVERSE_SEQUENCE = 112,
|
||||
BuiltinOperator_MATRIX_DIAG = 113,
|
||||
BuiltinOperator_QUANTIZE = 114,
|
||||
BuiltinOperator_MIN = BuiltinOperator_ADD,
|
||||
BuiltinOperator_MAX = BuiltinOperator_MATRIX_DIAG
|
||||
BuiltinOperator_MAX = BuiltinOperator_QUANTIZE
|
||||
};
|
||||
|
||||
inline const BuiltinOperator (&EnumValuesBuiltinOperator())[113] {
|
||||
inline const BuiltinOperator (&EnumValuesBuiltinOperator())[114] {
|
||||
static const BuiltinOperator values[] = {
|
||||
BuiltinOperator_ADD,
|
||||
BuiltinOperator_AVERAGE_POOL_2D,
|
||||
@ -676,7 +680,8 @@ inline const BuiltinOperator (&EnumValuesBuiltinOperator())[113] {
|
||||
BuiltinOperator_RANK,
|
||||
BuiltinOperator_ELU,
|
||||
BuiltinOperator_REVERSE_SEQUENCE,
|
||||
BuiltinOperator_MATRIX_DIAG
|
||||
BuiltinOperator_MATRIX_DIAG,
|
||||
BuiltinOperator_QUANTIZE
|
||||
};
|
||||
return values;
|
||||
}
|
||||
@ -797,6 +802,7 @@ inline const char * const *EnumNamesBuiltinOperator() {
|
||||
"ELU",
|
||||
"REVERSE_SEQUENCE",
|
||||
"MATRIX_DIAG",
|
||||
"QUANTIZE",
|
||||
nullptr
|
||||
};
|
||||
return names;
|
||||
@ -897,11 +903,12 @@ enum BuiltinOptions {
|
||||
BuiltinOptions_RankOptions = 86,
|
||||
BuiltinOptions_ReverseSequenceOptions = 87,
|
||||
BuiltinOptions_MatrixDiagOptions = 88,
|
||||
BuiltinOptions_QuantizeOptions = 89,
|
||||
BuiltinOptions_MIN = BuiltinOptions_NONE,
|
||||
BuiltinOptions_MAX = BuiltinOptions_MatrixDiagOptions
|
||||
BuiltinOptions_MAX = BuiltinOptions_QuantizeOptions
|
||||
};
|
||||
|
||||
inline const BuiltinOptions (&EnumValuesBuiltinOptions())[89] {
|
||||
inline const BuiltinOptions (&EnumValuesBuiltinOptions())[90] {
|
||||
static const BuiltinOptions values[] = {
|
||||
BuiltinOptions_NONE,
|
||||
BuiltinOptions_Conv2DOptions,
|
||||
@ -991,7 +998,8 @@ inline const BuiltinOptions (&EnumValuesBuiltinOptions())[89] {
|
||||
BuiltinOptions_WhereOptions,
|
||||
BuiltinOptions_RankOptions,
|
||||
BuiltinOptions_ReverseSequenceOptions,
|
||||
BuiltinOptions_MatrixDiagOptions
|
||||
BuiltinOptions_MatrixDiagOptions,
|
||||
BuiltinOptions_QuantizeOptions
|
||||
};
|
||||
return values;
|
||||
}
|
||||
@ -1087,6 +1095,7 @@ inline const char * const *EnumNamesBuiltinOptions() {
|
||||
"RankOptions",
|
||||
"ReverseSequenceOptions",
|
||||
"MatrixDiagOptions",
|
||||
"QuantizeOptions",
|
||||
nullptr
|
||||
};
|
||||
return names;
|
||||
@ -1453,6 +1462,10 @@ template<> struct BuiltinOptionsTraits<MatrixDiagOptions> {
|
||||
static const BuiltinOptions enum_value = BuiltinOptions_MatrixDiagOptions;
|
||||
};
|
||||
|
||||
template<> struct BuiltinOptionsTraits<QuantizeOptions> {
|
||||
static const BuiltinOptions enum_value = BuiltinOptions_QuantizeOptions;
|
||||
};
|
||||
|
||||
struct BuiltinOptionsUnion {
|
||||
BuiltinOptions type;
|
||||
void *value;
|
||||
@ -2188,6 +2201,14 @@ struct BuiltinOptionsUnion {
|
||||
return type == BuiltinOptions_MatrixDiagOptions ?
|
||||
reinterpret_cast<const MatrixDiagOptionsT *>(value) : nullptr;
|
||||
}
|
||||
QuantizeOptionsT *AsQuantizeOptions() {
|
||||
return type == BuiltinOptions_QuantizeOptions ?
|
||||
reinterpret_cast<QuantizeOptionsT *>(value) : nullptr;
|
||||
}
|
||||
const QuantizeOptionsT *AsQuantizeOptions() const {
|
||||
return type == BuiltinOptions_QuantizeOptions ?
|
||||
reinterpret_cast<const QuantizeOptionsT *>(value) : nullptr;
|
||||
}
|
||||
};
|
||||
|
||||
bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type);
|
||||
@ -7630,6 +7651,46 @@ inline flatbuffers::Offset<MatrixDiagOptions> CreateMatrixDiagOptions(
|
||||
|
||||
flatbuffers::Offset<MatrixDiagOptions> CreateMatrixDiagOptions(flatbuffers::FlatBufferBuilder &_fbb, const MatrixDiagOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
|
||||
|
||||
struct QuantizeOptionsT : public flatbuffers::NativeTable {
|
||||
typedef QuantizeOptions TableType;
|
||||
QuantizeOptionsT() {
|
||||
}
|
||||
};
|
||||
|
||||
struct QuantizeOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
||||
typedef QuantizeOptionsT NativeTableType;
|
||||
bool Verify(flatbuffers::Verifier &verifier) const {
|
||||
return VerifyTableStart(verifier) &&
|
||||
verifier.EndTable();
|
||||
}
|
||||
QuantizeOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
|
||||
void UnPackTo(QuantizeOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
|
||||
static flatbuffers::Offset<QuantizeOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const QuantizeOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
|
||||
};
|
||||
|
||||
struct QuantizeOptionsBuilder {
|
||||
flatbuffers::FlatBufferBuilder &fbb_;
|
||||
flatbuffers::uoffset_t start_;
|
||||
explicit QuantizeOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
|
||||
: fbb_(_fbb) {
|
||||
start_ = fbb_.StartTable();
|
||||
}
|
||||
QuantizeOptionsBuilder &operator=(const QuantizeOptionsBuilder &);
|
||||
flatbuffers::Offset<QuantizeOptions> Finish() {
|
||||
const auto end = fbb_.EndTable(start_);
|
||||
auto o = flatbuffers::Offset<QuantizeOptions>(end);
|
||||
return o;
|
||||
}
|
||||
};
|
||||
|
||||
inline flatbuffers::Offset<QuantizeOptions> CreateQuantizeOptions(
|
||||
flatbuffers::FlatBufferBuilder &_fbb) {
|
||||
QuantizeOptionsBuilder builder_(_fbb);
|
||||
return builder_.Finish();
|
||||
}
|
||||
|
||||
flatbuffers::Offset<QuantizeOptions> CreateQuantizeOptions(flatbuffers::FlatBufferBuilder &_fbb, const QuantizeOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
|
||||
|
||||
struct OperatorCodeT : public flatbuffers::NativeTable {
|
||||
typedef OperatorCode TableType;
|
||||
BuiltinOperator builtin_code;
|
||||
@ -8027,6 +8088,9 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
||||
const MatrixDiagOptions *builtin_options_as_MatrixDiagOptions() const {
|
||||
return builtin_options_type() == BuiltinOptions_MatrixDiagOptions ? static_cast<const MatrixDiagOptions *>(builtin_options()) : nullptr;
|
||||
}
|
||||
const QuantizeOptions *builtin_options_as_QuantizeOptions() const {
|
||||
return builtin_options_type() == BuiltinOptions_QuantizeOptions ? static_cast<const QuantizeOptions *>(builtin_options()) : nullptr;
|
||||
}
|
||||
const flatbuffers::Vector<uint8_t> *custom_options() const {
|
||||
return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
|
||||
}
|
||||
@ -8410,6 +8474,10 @@ template<> inline const MatrixDiagOptions *Operator::builtin_options_as<MatrixDi
|
||||
return builtin_options_as_MatrixDiagOptions();
|
||||
}
|
||||
|
||||
template<> inline const QuantizeOptions *Operator::builtin_options_as<QuantizeOptions>() const {
|
||||
return builtin_options_as_QuantizeOptions();
|
||||
}
|
||||
|
||||
struct OperatorBuilder {
|
||||
flatbuffers::FlatBufferBuilder &fbb_;
|
||||
flatbuffers::uoffset_t start_;
|
||||
@ -11247,6 +11315,29 @@ inline flatbuffers::Offset<MatrixDiagOptions> CreateMatrixDiagOptions(flatbuffer
|
||||
_fbb);
|
||||
}
|
||||
|
||||
inline QuantizeOptionsT *QuantizeOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
|
||||
auto _o = new QuantizeOptionsT();
|
||||
UnPackTo(_o, _resolver);
|
||||
return _o;
|
||||
}
|
||||
|
||||
inline void QuantizeOptions::UnPackTo(QuantizeOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
|
||||
(void)_o;
|
||||
(void)_resolver;
|
||||
}
|
||||
|
||||
inline flatbuffers::Offset<QuantizeOptions> QuantizeOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const QuantizeOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
|
||||
return CreateQuantizeOptions(_fbb, _o, _rehasher);
|
||||
}
|
||||
|
||||
inline flatbuffers::Offset<QuantizeOptions> CreateQuantizeOptions(flatbuffers::FlatBufferBuilder &_fbb, const QuantizeOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
|
||||
(void)_rehasher;
|
||||
(void)_o;
|
||||
struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const QuantizeOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
|
||||
return tflite::CreateQuantizeOptions(
|
||||
_fbb);
|
||||
}
|
||||
|
||||
inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
|
||||
auto _o = new OperatorCodeT();
|
||||
UnPackTo(_o, _resolver);
|
||||
@ -11857,6 +11948,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob
|
||||
auto ptr = reinterpret_cast<const MatrixDiagOptions *>(obj);
|
||||
return verifier.VerifyTable(ptr);
|
||||
}
|
||||
case BuiltinOptions_QuantizeOptions: {
|
||||
auto ptr = reinterpret_cast<const QuantizeOptions *>(obj);
|
||||
return verifier.VerifyTable(ptr);
|
||||
}
|
||||
default: return false;
|
||||
}
|
||||
}
|
||||
@ -12227,6 +12322,10 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c
|
||||
auto ptr = reinterpret_cast<const MatrixDiagOptions *>(obj);
|
||||
return ptr->UnPack(resolver);
|
||||
}
|
||||
case BuiltinOptions_QuantizeOptions: {
|
||||
auto ptr = reinterpret_cast<const QuantizeOptions *>(obj);
|
||||
return ptr->UnPack(resolver);
|
||||
}
|
||||
default: return nullptr;
|
||||
}
|
||||
}
|
||||
@ -12585,6 +12684,10 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff
|
||||
auto ptr = reinterpret_cast<const MatrixDiagOptionsT *>(value);
|
||||
return CreateMatrixDiagOptions(_fbb, ptr, _rehasher).Union();
|
||||
}
|
||||
case BuiltinOptions_QuantizeOptions: {
|
||||
auto ptr = reinterpret_cast<const QuantizeOptionsT *>(value);
|
||||
return CreateQuantizeOptions(_fbb, ptr, _rehasher).Union();
|
||||
}
|
||||
default: return 0;
|
||||
}
|
||||
}
|
||||
@ -12943,6 +13046,10 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL
|
||||
value = new MatrixDiagOptionsT(*reinterpret_cast<MatrixDiagOptionsT *>(u.value));
|
||||
break;
|
||||
}
|
||||
case BuiltinOptions_QuantizeOptions: {
|
||||
value = new QuantizeOptionsT(*reinterpret_cast<QuantizeOptionsT *>(u.value));
|
||||
break;
|
||||
}
|
||||
default:
|
||||
break;
|
||||
}
|
||||
@ -13390,6 +13497,11 @@ inline void BuiltinOptionsUnion::Reset() {
|
||||
delete ptr;
|
||||
break;
|
||||
}
|
||||
case BuiltinOptions_QuantizeOptions: {
|
||||
auto ptr = reinterpret_cast<QuantizeOptionsT *>(value);
|
||||
delete ptr;
|
||||
break;
|
||||
}
|
||||
default: break;
|
||||
}
|
||||
value = nullptr;
|
||||
|
Loading…
Reference in New Issue
Block a user