From 058717dd41ac7410120da47b05706eabcbf3c5ec Mon Sep 17 00:00:00 2001
From: Suharsh Sivakumar <suharshs@google.com>
Date: Mon, 12 Nov 2018 09:28:10 -0800
Subject: [PATCH] QuantizationDetails and CustomQuantization to schema.

Allows experimenting with new quantization techniques.

PiperOrigin-RevId: 221104574
---
 tensorflow/lite/schema/schema.fbs         |  24 +-
 tensorflow/lite/schema/schema_generated.h | 300 +++++++++++++++++++++-
 2 files changed, 313 insertions(+), 11 deletions(-)

diff --git a/tensorflow/lite/schema/schema.fbs b/tensorflow/lite/schema/schema.fbs
index 07b16f28121..9b0eae74c3b 100644
--- a/tensorflow/lite/schema/schema.fbs
+++ b/tensorflow/lite/schema/schema.fbs
@@ -39,16 +39,34 @@ enum TensorType : byte {
   BOOL = 6,
   INT16 = 7,
   COMPLEX64 = 8,
+  INT8 = 9,
 }
 
-// Parameters for converting a quantized tensor back to float. Given a
-// quantized value q, the corresponding float value f should be:
-//   f = scale * (q - zero_point)
+// Custom quantization parameters for experimenting with new quantization
+// techniques.
+table CustomQuantization {
+  custom:[byte];
+}
+
+// Represents a specific quantization technique's parameters.
+union QuantizationDetails {
+  CustomQuantization,
+}
+
+// Parameters for converting a quantized tensor back to float.
 table QuantizationParameters {
+  // These four parameters are the asymmetric linear quantization parameters.
+  // Given a quantized value q, the corresponding float value f should be:
+  //   f = scale * (q - zero_point)
+  // For other quantization types, the QuantizationDetails below is used.
   min:[float];  // For importing back into tensorflow.
   max:[float];  // For importing back into tensorflow.
   scale:[float];  // For dequantizing the tensor's values.
   zero_point:[long];
+
+  // If this is not none, the quantization parameters above are ignored and the
+  // value of the QuantizationDetails union below should be used.
+  details:QuantizationDetails;
 }
 
 table Tensor {
diff --git a/tensorflow/lite/schema/schema_generated.h b/tensorflow/lite/schema/schema_generated.h
index 479c0d658ba..b7885cfcc50 100755
--- a/tensorflow/lite/schema/schema_generated.h
+++ b/tensorflow/lite/schema/schema_generated.h
@@ -22,6 +22,9 @@ limitations under the License.
 
 namespace tflite {
 
+struct CustomQuantization;
+struct CustomQuantizationT;
+
 struct QuantizationParameters;
 struct QuantizationParametersT;
 
@@ -275,11 +278,12 @@ enum TensorType {
   TensorType_BOOL = 6,
   TensorType_INT16 = 7,
   TensorType_COMPLEX64 = 8,
+  TensorType_INT8 = 9,
   TensorType_MIN = TensorType_FLOAT32,
-  TensorType_MAX = TensorType_COMPLEX64
+  TensorType_MAX = TensorType_INT8
 };
 
-inline const TensorType (&EnumValuesTensorType())[9] {
+inline const TensorType (&EnumValuesTensorType())[10] {
   static const TensorType values[] = {
     TensorType_FLOAT32,
     TensorType_FLOAT16,
@@ -289,7 +293,8 @@ inline const TensorType (&EnumValuesTensorType())[9] {
     TensorType_STRING,
     TensorType_BOOL,
     TensorType_INT16,
-    TensorType_COMPLEX64
+    TensorType_COMPLEX64,
+    TensorType_INT8
   };
   return values;
 }
@@ -305,6 +310,7 @@ inline const char * const *EnumNamesTensorType() {
     "BOOL",
     "INT16",
     "COMPLEX64",
+    "INT8",
     nullptr
   };
   return names;
@@ -315,6 +321,87 @@ inline const char *EnumNameTensorType(TensorType e) {
   return EnumNamesTensorType()[index];
 }
 
+enum QuantizationDetails {
+  QuantizationDetails_NONE = 0,
+  QuantizationDetails_CustomQuantization = 1,
+  QuantizationDetails_MIN = QuantizationDetails_NONE,
+  QuantizationDetails_MAX = QuantizationDetails_CustomQuantization
+};
+
+inline const QuantizationDetails (&EnumValuesQuantizationDetails())[2] {
+  static const QuantizationDetails values[] = {
+    QuantizationDetails_NONE,
+    QuantizationDetails_CustomQuantization
+  };
+  return values;
+}
+
+inline const char * const *EnumNamesQuantizationDetails() {
+  static const char * const names[] = {
+    "NONE",
+    "CustomQuantization",
+    nullptr
+  };
+  return names;
+}
+
+inline const char *EnumNameQuantizationDetails(QuantizationDetails e) {
+  const size_t index = static_cast<int>(e);
+  return EnumNamesQuantizationDetails()[index];
+}
+
+template<typename T> struct QuantizationDetailsTraits {
+  static const QuantizationDetails enum_value = QuantizationDetails_NONE;
+};
+
+template<> struct QuantizationDetailsTraits<CustomQuantization> {
+  static const QuantizationDetails enum_value = QuantizationDetails_CustomQuantization;
+};
+
+struct QuantizationDetailsUnion {
+  QuantizationDetails type;
+  void *value;
+
+  QuantizationDetailsUnion() : type(QuantizationDetails_NONE), value(nullptr) {}
+  QuantizationDetailsUnion(QuantizationDetailsUnion&& u) FLATBUFFERS_NOEXCEPT :
+    type(QuantizationDetails_NONE), value(nullptr)
+    { std::swap(type, u.type); std::swap(value, u.value); }
+  QuantizationDetailsUnion(const QuantizationDetailsUnion &) FLATBUFFERS_NOEXCEPT;
+  QuantizationDetailsUnion &operator=(const QuantizationDetailsUnion &u) FLATBUFFERS_NOEXCEPT
+    { QuantizationDetailsUnion t(u); std::swap(type, t.type); std::swap(value, t.value); return *this; }
+  QuantizationDetailsUnion &operator=(QuantizationDetailsUnion &&u) FLATBUFFERS_NOEXCEPT
+    { std::swap(type, u.type); std::swap(value, u.value); return *this; }
+  ~QuantizationDetailsUnion() { Reset(); }
+
+  void Reset();
+
+#ifndef FLATBUFFERS_CPP98_STL
+  template <typename T>
+  void Set(T&& val) {
+    Reset();
+    type = QuantizationDetailsTraits<typename T::TableType>::enum_value;
+    if (type != QuantizationDetails_NONE) {
+      value = new T(std::forward<T>(val));
+    }
+  }
+#endif  // FLATBUFFERS_CPP98_STL
+
+  static void *UnPack(const void *obj, QuantizationDetails type, const flatbuffers::resolver_function_t *resolver);
+  flatbuffers::Offset<void> Pack(flatbuffers::FlatBufferBuilder &_fbb, const flatbuffers::rehasher_function_t *_rehasher = nullptr) const;
+
+  CustomQuantizationT *AsCustomQuantization() {
+    return type == QuantizationDetails_CustomQuantization ?
+      reinterpret_cast<CustomQuantizationT *>(value) : nullptr;
+  }
+  const CustomQuantizationT *AsCustomQuantization() const {
+    return type == QuantizationDetails_CustomQuantization ?
+      reinterpret_cast<const CustomQuantizationT *>(value) : nullptr;
+  }
+};
+
+bool VerifyQuantizationDetails(flatbuffers::Verifier &verifier, const void *obj, QuantizationDetails type);
+bool VerifyQuantizationDetailsVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector<flatbuffers::Offset<void>> *values, const flatbuffers::Vector<uint8_t> *types);
+
 enum BuiltinOperator {
   BuiltinOperator_ADD = 0,
   BuiltinOperator_AVERAGE_POOL_2D = 1,
@@ -2024,12 +2111,75 @@ inline const char *EnumNameCustomOptionsFormat(CustomOptionsFormat e) {
   return EnumNamesCustomOptionsFormat()[index];
 }
 
+struct CustomQuantizationT : public flatbuffers::NativeTable {
+  typedef CustomQuantization TableType;
+  std::vector<int8_t> custom;
+  CustomQuantizationT() {
+  }
+};
+
+struct CustomQuantization FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+  typedef CustomQuantizationT NativeTableType;
+  enum {
+    VT_CUSTOM = 4
+  };
+  const flatbuffers::Vector<int8_t> *custom() const {
+    return GetPointer<const flatbuffers::Vector<int8_t> *>(VT_CUSTOM);
+  }
+  bool Verify(flatbuffers::Verifier &verifier) const {
+    return VerifyTableStart(verifier) &&
+           VerifyOffset(verifier, VT_CUSTOM) &&
+           verifier.VerifyVector(custom()) &&
+           verifier.EndTable();
+  }
+  CustomQuantizationT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+  void UnPackTo(CustomQuantizationT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+  static flatbuffers::Offset<CustomQuantization> Pack(flatbuffers::FlatBufferBuilder &_fbb, const CustomQuantizationT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct CustomQuantizationBuilder {
+  flatbuffers::FlatBufferBuilder &fbb_;
+  flatbuffers::uoffset_t start_;
+  void add_custom(flatbuffers::Offset<flatbuffers::Vector<int8_t>> custom) {
+    fbb_.AddOffset(CustomQuantization::VT_CUSTOM, custom);
+  }
+  explicit CustomQuantizationBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+        : fbb_(_fbb) {
+    start_ = fbb_.StartTable();
+  }
+  CustomQuantizationBuilder &operator=(const CustomQuantizationBuilder &);
+  flatbuffers::Offset<CustomQuantization> Finish() {
+    const auto end = fbb_.EndTable(start_);
+    auto o = flatbuffers::Offset<CustomQuantization>(end);
+    return o;
+  }
+};
+
+inline flatbuffers::Offset<CustomQuantization> CreateCustomQuantization(
+    flatbuffers::FlatBufferBuilder &_fbb,
+    flatbuffers::Offset<flatbuffers::Vector<int8_t>> custom = 0) {
+  CustomQuantizationBuilder builder_(_fbb);
+  builder_.add_custom(custom);
+  return builder_.Finish();
+}
+
+inline flatbuffers::Offset<CustomQuantization> CreateCustomQuantizationDirect(
+    flatbuffers::FlatBufferBuilder &_fbb,
+    const std::vector<int8_t> *custom = nullptr) {
+  return tflite::CreateCustomQuantization(
+      _fbb,
+      custom ? _fbb.CreateVector<int8_t>(*custom) : 0);
+}
+
+flatbuffers::Offset<CustomQuantization> CreateCustomQuantization(flatbuffers::FlatBufferBuilder &_fbb, const CustomQuantizationT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
 struct QuantizationParametersT : public flatbuffers::NativeTable {
   typedef QuantizationParameters TableType;
   std::vector<float> min;
   std::vector<float> max;
   std::vector<float> scale;
   std::vector<int64_t> zero_point;
+  QuantizationDetailsUnion details;
   QuantizationParametersT() {
   }
 };
@@ -2040,7 +2190,9 @@ struct QuantizationParameters FLATBUFFERS_FINAL_CLASS : private flatbuffers::Tab
     VT_MIN = 4,
     VT_MAX = 6,
     VT_SCALE = 8,
-    VT_ZERO_POINT = 10
+    VT_ZERO_POINT = 10,
+    VT_DETAILS_TYPE = 12,
+    VT_DETAILS = 14
   };
   const flatbuffers::Vector<float> *min() const {
     return GetPointer<const flatbuffers::Vector<float> *>(VT_MIN);
@@ -2054,6 +2206,16 @@ struct QuantizationParameters FLATBUFFERS_FINAL_CLASS : private flatbuffers::Tab
   const flatbuffers::Vector<int64_t> *zero_point() const {
     return GetPointer<const flatbuffers::Vector<int64_t> *>(VT_ZERO_POINT);
   }
+  QuantizationDetails details_type() const {
+    return static_cast<QuantizationDetails>(GetField<uint8_t>(VT_DETAILS_TYPE, 0));
+  }
+  const void *details() const {
+    return GetPointer<const void *>(VT_DETAILS);
+  }
+  template<typename T> const T *details_as() const;
+  const CustomQuantization *details_as_CustomQuantization() const {
+    return details_type() == QuantizationDetails_CustomQuantization ? static_cast<const CustomQuantization *>(details()) : nullptr;
+  }
   bool Verify(flatbuffers::Verifier &verifier) const {
     return VerifyTableStart(verifier) &&
            VerifyOffset(verifier, VT_MIN) &&
@@ -2064,6 +2226,9 @@ struct QuantizationParameters FLATBUFFERS_FINAL_CLASS : private flatbuffers::Tab
            verifier.VerifyVector(scale()) &&
            VerifyOffset(verifier, VT_ZERO_POINT) &&
            verifier.VerifyVector(zero_point()) &&
+           VerifyField<uint8_t>(verifier, VT_DETAILS_TYPE) &&
+           VerifyOffset(verifier, VT_DETAILS) &&
+           VerifyQuantizationDetails(verifier, details(), details_type()) &&
            verifier.EndTable();
   }
   QuantizationParametersT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@@ -2071,6 +2236,10 @@ struct QuantizationParameters FLATBUFFERS_FINAL_CLASS : private flatbuffers::Tab
   static flatbuffers::Offset<QuantizationParameters> Pack(flatbuffers::FlatBufferBuilder &_fbb, const QuantizationParametersT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
 };
 
+template<> inline const CustomQuantization *QuantizationParameters::details_as<CustomQuantization>() const {
+  return details_as_CustomQuantization();
+}
+
 struct QuantizationParametersBuilder {
   flatbuffers::FlatBufferBuilder &fbb_;
   flatbuffers::uoffset_t start_;
@@ -2086,6 +2255,12 @@ struct QuantizationParametersBuilder {
   void add_zero_point(flatbuffers::Offset<flatbuffers::Vector<int64_t>> zero_point) {
     fbb_.AddOffset(QuantizationParameters::VT_ZERO_POINT, zero_point);
   }
+  void add_details_type(QuantizationDetails details_type) {
+    fbb_.AddElement<uint8_t>(QuantizationParameters::VT_DETAILS_TYPE, static_cast<uint8_t>(details_type), 0);
+  }
+  void add_details(flatbuffers::Offset<void> details) {
+    fbb_.AddOffset(QuantizationParameters::VT_DETAILS, details);
+  }
   explicit QuantizationParametersBuilder(flatbuffers::FlatBufferBuilder &_fbb)
         : fbb_(_fbb) {
     start_ = fbb_.StartTable();
@@ -2103,12 +2278,16 @@ inline flatbuffers::Offset<QuantizationParameters> CreateQuantizationParameters(
     flatbuffers::Offset<flatbuffers::Vector<float>> min = 0,
     flatbuffers::Offset<flatbuffers::Vector<float>> max = 0,
     flatbuffers::Offset<flatbuffers::Vector<float>> scale = 0,
-    flatbuffers::Offset<flatbuffers::Vector<int64_t>> zero_point = 0) {
+    flatbuffers::Offset<flatbuffers::Vector<int64_t>> zero_point = 0,
+    QuantizationDetails details_type = QuantizationDetails_NONE,
+    flatbuffers::Offset<void> details = 0) {
   QuantizationParametersBuilder builder_(_fbb);
+  builder_.add_details(details);
   builder_.add_zero_point(zero_point);
   builder_.add_scale(scale);
   builder_.add_max(max);
   builder_.add_min(min);
+  builder_.add_details_type(details_type);
   return builder_.Finish();
 }
 
@@ -2117,13 +2296,17 @@ inline flatbuffers::Offset<QuantizationParameters> CreateQuantizationParametersD
     const std::vector<float> *min = nullptr,
     const std::vector<float> *max = nullptr,
     const std::vector<float> *scale = nullptr,
-    const std::vector<int64_t> *zero_point = nullptr) {
+    const std::vector<int64_t> *zero_point = nullptr,
+    QuantizationDetails details_type = QuantizationDetails_NONE,
+    flatbuffers::Offset<void> details = 0) {
   return tflite::CreateQuantizationParameters(
       _fbb,
       min ? _fbb.CreateVector<float>(*min) : 0,
       max ? _fbb.CreateVector<float>(*max) : 0,
       scale ? _fbb.CreateVector<float>(*scale) : 0,
-      zero_point ? _fbb.CreateVector<int64_t>(*zero_point) : 0);
+      zero_point ? _fbb.CreateVector<int64_t>(*zero_point) : 0,
+      details_type,
+      details);
 }
 
 flatbuffers::Offset<QuantizationParameters> CreateQuantizationParameters(flatbuffers::FlatBufferBuilder &_fbb, const QuantizationParametersT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
@@ -7534,6 +7717,32 @@ inline flatbuffers::Offset<Model> CreateModelDirect(
 
 flatbuffers::Offset<Model> CreateModel(flatbuffers::FlatBufferBuilder &_fbb, const ModelT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
 
+inline CustomQuantizationT *CustomQuantization::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+  auto _o = new CustomQuantizationT();
+  UnPackTo(_o, _resolver);
+  return _o;
+}
+
+inline void CustomQuantization::UnPackTo(CustomQuantizationT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+  (void)_o;
+  (void)_resolver;
+  { auto _e = custom(); if (_e) { _o->custom.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->custom[_i] = _e->Get(_i); } } };
+}
+
+inline flatbuffers::Offset<CustomQuantization> CustomQuantization::Pack(flatbuffers::FlatBufferBuilder &_fbb, const CustomQuantizationT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+  return CreateCustomQuantization(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<CustomQuantization> CreateCustomQuantization(flatbuffers::FlatBufferBuilder &_fbb, const CustomQuantizationT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+  (void)_rehasher;
+  (void)_o;
+  struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const CustomQuantizationT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+  auto _custom = _o->custom.size() ? _fbb.CreateVector(_o->custom) : 0;
+  return tflite::CreateCustomQuantization(
+      _fbb,
+      _custom);
+}
+
 inline QuantizationParametersT *QuantizationParameters::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
   auto _o = new QuantizationParametersT();
   UnPackTo(_o, _resolver);
@@ -7547,6 +7756,8 @@ inline void QuantizationParameters::UnPackTo(QuantizationParametersT *_o, const
   { auto _e = max(); if (_e) { _o->max.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->max[_i] = _e->Get(_i); } } };
   { auto _e = scale(); if (_e) { _o->scale.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->scale[_i] = _e->Get(_i); } } };
   { auto _e = zero_point(); if (_e) { _o->zero_point.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->zero_point[_i] = _e->Get(_i); } } };
+  { auto _e = details_type(); _o->details.type = _e; };
+  { auto _e = details(); if (_e) _o->details.value = QuantizationDetailsUnion::UnPack(_e, details_type(), _resolver); };
 }
 
 inline flatbuffers::Offset<QuantizationParameters> QuantizationParameters::Pack(flatbuffers::FlatBufferBuilder &_fbb, const QuantizationParametersT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
@@ -7561,12 +7772,16 @@ inline flatbuffers::Offset<QuantizationParameters> CreateQuantizationParameters(
   auto _max = _o->max.size() ? _fbb.CreateVector(_o->max) : 0;
   auto _scale = _o->scale.size() ? _fbb.CreateVector(_o->scale) : 0;
   auto _zero_point = _o->zero_point.size() ? _fbb.CreateVector(_o->zero_point) : 0;
+  auto _details_type = _o->details.type;
+  auto _details = _o->details.Pack(_fbb);
   return tflite::CreateQuantizationParameters(
       _fbb,
       _min,
       _max,
       _scale,
-      _zero_point);
+      _zero_point,
+      _details_type,
+      _details);
 }
 
 inline TensorT *Tensor::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
@@ -9775,6 +9990,75 @@ inline flatbuffers::Offset<Model> CreateModel(flatbuffers::FlatBufferBuilder &_f
       _metadata_buffer);
 }
 
+inline bool VerifyQuantizationDetails(flatbuffers::Verifier &verifier, const void *obj, QuantizationDetails type) {
+  switch (type) {
+    case QuantizationDetails_NONE: {
+      return true;
+    }
+    case QuantizationDetails_CustomQuantization: {
+      auto ptr = reinterpret_cast<const CustomQuantization *>(obj);
+      return verifier.VerifyTable(ptr);
+    }
+    default: return false;
+  }
+}
+
+inline bool VerifyQuantizationDetailsVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector<flatbuffers::Offset<void>> *values, const flatbuffers::Vector<uint8_t> *types) {
+  if (!values || !types) return !values && !types;
+  if (values->size() != types->size()) return false;
+  for (flatbuffers::uoffset_t i = 0; i < values->size(); ++i) {
+    if (!VerifyQuantizationDetails(
+        verifier,  values->Get(i), types->GetEnum<QuantizationDetails>(i))) {
+      return false;
+    }
+  }
+  return true;
+}
+
+inline void *QuantizationDetailsUnion::UnPack(const void *obj, QuantizationDetails type, const flatbuffers::resolver_function_t *resolver) {
+  switch (type) {
+    case QuantizationDetails_CustomQuantization: {
+      auto ptr = reinterpret_cast<const CustomQuantization *>(obj);
+      return ptr->UnPack(resolver);
+    }
+    default: return nullptr;
+  }
+}
+
+inline flatbuffers::Offset<void> QuantizationDetailsUnion::Pack(flatbuffers::FlatBufferBuilder &_fbb, const flatbuffers::rehasher_function_t *_rehasher) const {
+  switch (type) {
+    case QuantizationDetails_CustomQuantization: {
+      auto ptr = reinterpret_cast<const CustomQuantizationT *>(value);
+      return CreateCustomQuantization(_fbb, ptr, _rehasher).Union();
+    }
+    default: return 0;
+  }
+}
+
+inline QuantizationDetailsUnion::QuantizationDetailsUnion(const QuantizationDetailsUnion &u) FLATBUFFERS_NOEXCEPT : type(u.type), value(nullptr) {
+  switch (type) {
+    case QuantizationDetails_CustomQuantization: {
+      value = new CustomQuantizationT(*reinterpret_cast<CustomQuantizationT *>(u.value));
+      break;
+    }
+    default:
+      break;
+  }
+}
+
+inline void QuantizationDetailsUnion::Reset() {
+  switch (type) {
+    case QuantizationDetails_CustomQuantization: {
+      auto ptr = reinterpret_cast<CustomQuantizationT *>(value);
+      delete ptr;
+      break;
+    }
+    default: break;
+  }
+  value = nullptr;
+  type = QuantizationDetails_NONE;
+}
+
 inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type) {
   switch (type) {
     case BuiltinOptions_NONE: {