diff --git a/tensorflow/lite/kernels/test_util.h b/tensorflow/lite/kernels/test_util.h index 9a3eddcfcb4..d0c765edaca 100644 --- a/tensorflow/lite/kernels/test_util.h +++ b/tensorflow/lite/kernels/test_util.h @@ -632,10 +632,17 @@ class SingleOpModel { dims_count); for (int i = 0; i < dims_count; i++) { const int metadata_idx = 2 * i; + auto array_segments = + CreateInt32Vector(builder_, + builder_.CreateVector(dim_metadata[metadata_idx])) + .Union(); + auto array_indices = + CreateInt32Vector( + builder_, builder_.CreateVector(dim_metadata[metadata_idx + 1])) + .Union(); fb_dim_metadata[i] = CreateDimensionMetadata( - builder_, DimensionType_SPARSE_CSR, 0, - builder_.CreateVector(dim_metadata[metadata_idx]), - builder_.CreateVector(dim_metadata[metadata_idx + 1])); + builder_, DimensionType_SPARSE_CSR, 0, SparseIndexVector_Int32Vector, + array_segments, SparseIndexVector_Int32Vector, array_indices); } flatbuffers::Offset<SparsityParameters> s_param = CreateSparsityParameters( diff --git a/tensorflow/lite/model.cc b/tensorflow/lite/model.cc index 22a4cf21213..8f470713e1b 100644 --- a/tensorflow/lite/model.cc +++ b/tensorflow/lite/model.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/core/api/error_reporter.h" #include "tensorflow/lite/core/api/flatbuffer_conversions.h" +#include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/util.h" #include "tensorflow/lite/version.h" @@ -40,6 +41,45 @@ namespace { ErrorReporter* ValidateErrorReporter(ErrorReporter* e) { return e ? e : DefaultErrorReporter(); } + +template <typename T> +void Copy(const T* data_ptr, TfLiteIntArray** arr) { + int size = data_ptr->values()->size(); + *arr = TfLiteIntArrayCreate(size); + for (int i = 0; i < size; i++) { + (*arr)->data[i] = static_cast<int>(data_ptr->values()->Get(i)); + } +} + +void ParseSparseIndexVector(const DimensionMetadata* src, + TfLiteDimensionMetadata* tgt) { + switch (src->array_segments_type()) { + case SparseIndexVector_Int32Vector: + Copy(src->array_segments_as_Int32Vector(), &tgt->array_segments); + break; + case SparseIndexVector_Uint16Vector: + Copy(src->array_segments_as_Uint16Vector(), &tgt->array_segments); + break; + case SparseIndexVector_Uint8Vector: + Copy(src->array_segments_as_Uint8Vector(), &tgt->array_segments); + break; + default: + break; + } + switch (src->array_indices_type()) { + case SparseIndexVector_Int32Vector: + Copy(src->array_indices_as_Int32Vector(), &tgt->array_indices); + break; + case SparseIndexVector_Uint16Vector: + Copy(src->array_indices_as_Uint16Vector(), &tgt->array_indices); + break; + case SparseIndexVector_Uint8Vector: + Copy(src->array_indices_as_Uint8Vector(), &tgt->array_indices); + break; + default: + break; + } +} } // namespace const char* kEmptyTensorName = ""; @@ -422,8 +462,6 @@ TfLiteStatus InterpreterBuilder::ParseQuantization( return kTfLiteOk; } -// TODO(b/145614687): Add sparse tensor verification check in -// lite/tools/verifier.cc. TfLiteStatus InterpreterBuilder::ParseSparsity( const SparsityParameters* src_sparsity, TfLiteSparsity** sparsity_ptr) { if (!src_sparsity) { @@ -492,18 +530,7 @@ TfLiteStatus InterpreterBuilder::ParseSparsity( if (tgt_metadata->format == kTfLiteDimDense) { tgt_metadata->dense_size = src_metadata->dense_size(); } else { - const int array_segments_size = src_metadata->array_segments()->size(); - tgt_metadata->array_segments = TfLiteIntArrayCreate(array_segments_size); - for (int j = 0; j < array_segments_size; j++) { - tgt_metadata->array_segments->data[j] = - src_metadata->array_segments()->Get(j); - } - const int array_indices_size = src_metadata->array_indices()->size(); - tgt_metadata->array_indices = TfLiteIntArrayCreate(array_indices_size); - for (int j = 0; j < array_indices_size; j++) { - tgt_metadata->array_indices->data[j] = - src_metadata->array_indices()->Get(j); - } + ParseSparseIndexVector(src_metadata, tgt_metadata); } } diff --git a/tensorflow/lite/schema/schema.fbs b/tensorflow/lite/schema/schema.fbs index e7d5eaed29f..0553e293f6a 100644 --- a/tensorflow/lite/schema/schema.fbs +++ b/tensorflow/lite/schema/schema.fbs @@ -108,6 +108,27 @@ enum DimensionType : byte { SPARSE_CSR = 1, } +table Int32Vector { + values:[int]; +} + +table Uint16Vector { + values:[ushort] (force_align: 4); +} + +table Uint8Vector { + values:[ubyte] (force_align: 4); +} + +// Variable-typed buffer to store the index metadata for a sparse dimension. +// The widest type is Int32 instead of UInt32 because tensor's shape is a int32 +// vector. We don't want the per-dimensional index to overflow that range. +union SparseIndexVector { + Int32Vector, + Uint16Vector, + Uint8Vector +} + table DimensionMetadata { // Whether a dimension is dense or sparse. format:DimensionType; @@ -123,8 +144,8 @@ table DimensionMetadata { // format, where the first array is row pointers and the second array is // column indices). dense_size:int; - array_segments:[int]; - array_indices:[int]; + array_segments:SparseIndexVector; + array_indices:SparseIndexVector; } // Parameters to encode a sparse TfLite tensor. diff --git a/tensorflow/lite/schema/schema_generated.h b/tensorflow/lite/schema/schema_generated.h index b91a2f0343d..282433d7ccc 100755 --- a/tensorflow/lite/schema/schema_generated.h +++ b/tensorflow/lite/schema/schema_generated.h @@ -28,6 +28,15 @@ struct CustomQuantizationT; struct QuantizationParameters; struct QuantizationParametersT; +struct Int32Vector; +struct Int32VectorT; + +struct Uint16Vector; +struct Uint16VectorT; + +struct Uint8Vector; +struct Uint8VectorT; + struct DimensionMetadata; struct DimensionMetadataT; @@ -522,6 +531,119 @@ inline const char *EnumNameDimensionType(DimensionType e) { return EnumNamesDimensionType()[index]; } +enum SparseIndexVector { + SparseIndexVector_NONE = 0, + SparseIndexVector_Int32Vector = 1, + SparseIndexVector_Uint16Vector = 2, + SparseIndexVector_Uint8Vector = 3, + SparseIndexVector_MIN = SparseIndexVector_NONE, + SparseIndexVector_MAX = SparseIndexVector_Uint8Vector +}; + +inline const SparseIndexVector (&EnumValuesSparseIndexVector())[4] { + static const SparseIndexVector values[] = { + SparseIndexVector_NONE, + SparseIndexVector_Int32Vector, + SparseIndexVector_Uint16Vector, + SparseIndexVector_Uint8Vector + }; + return values; +} + +inline const char * const *EnumNamesSparseIndexVector() { + static const char * const names[] = { + "NONE", + "Int32Vector", + "Uint16Vector", + "Uint8Vector", + nullptr + }; + return names; +} + +inline const char *EnumNameSparseIndexVector(SparseIndexVector e) { + if (e < SparseIndexVector_NONE || e > SparseIndexVector_Uint8Vector) return ""; + const size_t index = static_cast<size_t>(e); + return EnumNamesSparseIndexVector()[index]; +} + +template<typename T> struct SparseIndexVectorTraits { + static const SparseIndexVector enum_value = SparseIndexVector_NONE; +}; + +template<> struct SparseIndexVectorTraits<Int32Vector> { + static const SparseIndexVector enum_value = SparseIndexVector_Int32Vector; +}; + +template<> struct SparseIndexVectorTraits<Uint16Vector> { + static const SparseIndexVector enum_value = SparseIndexVector_Uint16Vector; +}; + +template<> struct SparseIndexVectorTraits<Uint8Vector> { + static const SparseIndexVector enum_value = SparseIndexVector_Uint8Vector; +}; + +struct SparseIndexVectorUnion { + SparseIndexVector type; + void *value; + + SparseIndexVectorUnion() : type(SparseIndexVector_NONE), value(nullptr) {} + SparseIndexVectorUnion(SparseIndexVectorUnion&& u) FLATBUFFERS_NOEXCEPT : + type(SparseIndexVector_NONE), value(nullptr) + { std::swap(type, u.type); std::swap(value, u.value); } + SparseIndexVectorUnion(const SparseIndexVectorUnion &) FLATBUFFERS_NOEXCEPT; + SparseIndexVectorUnion &operator=(const SparseIndexVectorUnion &u) FLATBUFFERS_NOEXCEPT + { SparseIndexVectorUnion t(u); std::swap(type, t.type); std::swap(value, t.value); return *this; } + SparseIndexVectorUnion &operator=(SparseIndexVectorUnion &&u) FLATBUFFERS_NOEXCEPT + { std::swap(type, u.type); std::swap(value, u.value); return *this; } + ~SparseIndexVectorUnion() { Reset(); } + + void Reset(); + +#ifndef FLATBUFFERS_CPP98_STL + template <typename T> + void Set(T&& val) { + using RT = typename std::remove_reference<T>::type; + Reset(); + type = SparseIndexVectorTraits<typename RT::TableType>::enum_value; + if (type != SparseIndexVector_NONE) { + value = new RT(std::forward<T>(val)); + } + } +#endif // FLATBUFFERS_CPP98_STL + + static void *UnPack(const void *obj, SparseIndexVector type, const flatbuffers::resolver_function_t *resolver); + flatbuffers::Offset<void> Pack(flatbuffers::FlatBufferBuilder &_fbb, const flatbuffers::rehasher_function_t *_rehasher = nullptr) const; + + Int32VectorT *AsInt32Vector() { + return type == SparseIndexVector_Int32Vector ? + reinterpret_cast<Int32VectorT *>(value) : nullptr; + } + const Int32VectorT *AsInt32Vector() const { + return type == SparseIndexVector_Int32Vector ? + reinterpret_cast<const Int32VectorT *>(value) : nullptr; + } + Uint16VectorT *AsUint16Vector() { + return type == SparseIndexVector_Uint16Vector ? + reinterpret_cast<Uint16VectorT *>(value) : nullptr; + } + const Uint16VectorT *AsUint16Vector() const { + return type == SparseIndexVector_Uint16Vector ? + reinterpret_cast<const Uint16VectorT *>(value) : nullptr; + } + Uint8VectorT *AsUint8Vector() { + return type == SparseIndexVector_Uint8Vector ? + reinterpret_cast<Uint8VectorT *>(value) : nullptr; + } + const Uint8VectorT *AsUint8Vector() const { + return type == SparseIndexVector_Uint8Vector ? + reinterpret_cast<const Uint8VectorT *>(value) : nullptr; + } +}; + +bool VerifySparseIndexVector(flatbuffers::Verifier &verifier, const void *obj, SparseIndexVector type); +bool VerifySparseIndexVectorVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector<flatbuffers::Offset<void>> *values, const flatbuffers::Vector<uint8_t> *types); + enum BuiltinOperator { BuiltinOperator_ADD = 0, BuiltinOperator_AVERAGE_POOL_2D = 1, @@ -2802,6 +2924,7 @@ inline flatbuffers::Offset<CustomQuantization> CreateCustomQuantization( inline flatbuffers::Offset<CustomQuantization> CreateCustomQuantizationDirect( flatbuffers::FlatBufferBuilder &_fbb, const std::vector<uint8_t> *custom = nullptr) { + if (custom) { _fbb.ForceVectorAlignment(custom->size(), sizeof(uint8_t), 16); } auto custom__ = custom ? _fbb.CreateVector<uint8_t>(*custom) : 0; return tflite::CreateCustomQuantization( _fbb, @@ -2966,12 +3089,201 @@ inline flatbuffers::Offset<QuantizationParameters> CreateQuantizationParametersD flatbuffers::Offset<QuantizationParameters> CreateQuantizationParameters(flatbuffers::FlatBufferBuilder &_fbb, const QuantizationParametersT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +struct Int32VectorT : public flatbuffers::NativeTable { + typedef Int32Vector TableType; + std::vector<int32_t> values; + Int32VectorT() { + } +}; + +struct Int32Vector FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef Int32VectorT NativeTableType; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_VALUES = 4 + }; + const flatbuffers::Vector<int32_t> *values() const { + return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_VALUES); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_VALUES) && + verifier.VerifyVector(values()) && + verifier.EndTable(); + } + Int32VectorT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(Int32VectorT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset<Int32Vector> Pack(flatbuffers::FlatBufferBuilder &_fbb, const Int32VectorT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct Int32VectorBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_values(flatbuffers::Offset<flatbuffers::Vector<int32_t>> values) { + fbb_.AddOffset(Int32Vector::VT_VALUES, values); + } + explicit Int32VectorBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + Int32VectorBuilder &operator=(const Int32VectorBuilder &); + flatbuffers::Offset<Int32Vector> Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<Int32Vector>(end); + return o; + } +}; + +inline flatbuffers::Offset<Int32Vector> CreateInt32Vector( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset<flatbuffers::Vector<int32_t>> values = 0) { + Int32VectorBuilder builder_(_fbb); + builder_.add_values(values); + return builder_.Finish(); +} + +inline flatbuffers::Offset<Int32Vector> CreateInt32VectorDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const std::vector<int32_t> *values = nullptr) { + auto values__ = values ? _fbb.CreateVector<int32_t>(*values) : 0; + return tflite::CreateInt32Vector( + _fbb, + values__); +} + +flatbuffers::Offset<Int32Vector> CreateInt32Vector(flatbuffers::FlatBufferBuilder &_fbb, const Int32VectorT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct Uint16VectorT : public flatbuffers::NativeTable { + typedef Uint16Vector TableType; + std::vector<uint16_t> values; + Uint16VectorT() { + } +}; + +struct Uint16Vector FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef Uint16VectorT NativeTableType; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_VALUES = 4 + }; + const flatbuffers::Vector<uint16_t> *values() const { + return GetPointer<const flatbuffers::Vector<uint16_t> *>(VT_VALUES); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_VALUES) && + verifier.VerifyVector(values()) && + verifier.EndTable(); + } + Uint16VectorT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(Uint16VectorT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset<Uint16Vector> Pack(flatbuffers::FlatBufferBuilder &_fbb, const Uint16VectorT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct Uint16VectorBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_values(flatbuffers::Offset<flatbuffers::Vector<uint16_t>> values) { + fbb_.AddOffset(Uint16Vector::VT_VALUES, values); + } + explicit Uint16VectorBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + Uint16VectorBuilder &operator=(const Uint16VectorBuilder &); + flatbuffers::Offset<Uint16Vector> Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<Uint16Vector>(end); + return o; + } +}; + +inline flatbuffers::Offset<Uint16Vector> CreateUint16Vector( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset<flatbuffers::Vector<uint16_t>> values = 0) { + Uint16VectorBuilder builder_(_fbb); + builder_.add_values(values); + return builder_.Finish(); +} + +inline flatbuffers::Offset<Uint16Vector> CreateUint16VectorDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const std::vector<uint16_t> *values = nullptr) { + auto values__ = values ? _fbb.CreateVector<uint16_t>(*values) : 0; + return tflite::CreateUint16Vector( + _fbb, + values__); +} + +flatbuffers::Offset<Uint16Vector> CreateUint16Vector(flatbuffers::FlatBufferBuilder &_fbb, const Uint16VectorT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct Uint8VectorT : public flatbuffers::NativeTable { + typedef Uint8Vector TableType; + std::vector<uint8_t> values; + Uint8VectorT() { + } +}; + +struct Uint8Vector FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef Uint8VectorT NativeTableType; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_VALUES = 4 + }; + const flatbuffers::Vector<uint8_t> *values() const { + return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_VALUES); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_VALUES) && + verifier.VerifyVector(values()) && + verifier.EndTable(); + } + Uint8VectorT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(Uint8VectorT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset<Uint8Vector> Pack(flatbuffers::FlatBufferBuilder &_fbb, const Uint8VectorT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct Uint8VectorBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_values(flatbuffers::Offset<flatbuffers::Vector<uint8_t>> values) { + fbb_.AddOffset(Uint8Vector::VT_VALUES, values); + } + explicit Uint8VectorBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + Uint8VectorBuilder &operator=(const Uint8VectorBuilder &); + flatbuffers::Offset<Uint8Vector> Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<Uint8Vector>(end); + return o; + } +}; + +inline flatbuffers::Offset<Uint8Vector> CreateUint8Vector( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset<flatbuffers::Vector<uint8_t>> values = 0) { + Uint8VectorBuilder builder_(_fbb); + builder_.add_values(values); + return builder_.Finish(); +} + +inline flatbuffers::Offset<Uint8Vector> CreateUint8VectorDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const std::vector<uint8_t> *values = nullptr) { + auto values__ = values ? _fbb.CreateVector<uint8_t>(*values) : 0; + return tflite::CreateUint8Vector( + _fbb, + values__); +} + +flatbuffers::Offset<Uint8Vector> CreateUint8Vector(flatbuffers::FlatBufferBuilder &_fbb, const Uint8VectorT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + struct DimensionMetadataT : public flatbuffers::NativeTable { typedef DimensionMetadata TableType; DimensionType format; int32_t dense_size; - std::vector<int32_t> array_segments; - std::vector<int32_t> array_indices; + SparseIndexVectorUnion array_segments; + SparseIndexVectorUnion array_indices; DimensionMetadataT() : format(DimensionType_DENSE), dense_size(0) { @@ -2983,8 +3295,10 @@ struct DimensionMetadata FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { VT_FORMAT = 4, VT_DENSE_SIZE = 6, - VT_ARRAY_SEGMENTS = 8, - VT_ARRAY_INDICES = 10 + VT_ARRAY_SEGMENTS_TYPE = 8, + VT_ARRAY_SEGMENTS = 10, + VT_ARRAY_INDICES_TYPE = 12, + VT_ARRAY_INDICES = 14 }; DimensionType format() const { return static_cast<DimensionType>(GetField<int8_t>(VT_FORMAT, 0)); @@ -2992,20 +3306,48 @@ struct DimensionMetadata FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { int32_t dense_size() const { return GetField<int32_t>(VT_DENSE_SIZE, 0); } - const flatbuffers::Vector<int32_t> *array_segments() const { - return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_ARRAY_SEGMENTS); + SparseIndexVector array_segments_type() const { + return static_cast<SparseIndexVector>(GetField<uint8_t>(VT_ARRAY_SEGMENTS_TYPE, 0)); } - const flatbuffers::Vector<int32_t> *array_indices() const { - return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_ARRAY_INDICES); + const void *array_segments() const { + return GetPointer<const void *>(VT_ARRAY_SEGMENTS); + } + template<typename T> const T *array_segments_as() const; + const Int32Vector *array_segments_as_Int32Vector() const { + return array_segments_type() == SparseIndexVector_Int32Vector ? static_cast<const Int32Vector *>(array_segments()) : nullptr; + } + const Uint16Vector *array_segments_as_Uint16Vector() const { + return array_segments_type() == SparseIndexVector_Uint16Vector ? static_cast<const Uint16Vector *>(array_segments()) : nullptr; + } + const Uint8Vector *array_segments_as_Uint8Vector() const { + return array_segments_type() == SparseIndexVector_Uint8Vector ? static_cast<const Uint8Vector *>(array_segments()) : nullptr; + } + SparseIndexVector array_indices_type() const { + return static_cast<SparseIndexVector>(GetField<uint8_t>(VT_ARRAY_INDICES_TYPE, 0)); + } + const void *array_indices() const { + return GetPointer<const void *>(VT_ARRAY_INDICES); + } + template<typename T> const T *array_indices_as() const; + const Int32Vector *array_indices_as_Int32Vector() const { + return array_indices_type() == SparseIndexVector_Int32Vector ? static_cast<const Int32Vector *>(array_indices()) : nullptr; + } + const Uint16Vector *array_indices_as_Uint16Vector() const { + return array_indices_type() == SparseIndexVector_Uint16Vector ? static_cast<const Uint16Vector *>(array_indices()) : nullptr; + } + const Uint8Vector *array_indices_as_Uint8Vector() const { + return array_indices_type() == SparseIndexVector_Uint8Vector ? static_cast<const Uint8Vector *>(array_indices()) : nullptr; } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField<int8_t>(verifier, VT_FORMAT) && VerifyField<int32_t>(verifier, VT_DENSE_SIZE) && + VerifyField<uint8_t>(verifier, VT_ARRAY_SEGMENTS_TYPE) && VerifyOffset(verifier, VT_ARRAY_SEGMENTS) && - verifier.VerifyVector(array_segments()) && + VerifySparseIndexVector(verifier, array_segments(), array_segments_type()) && + VerifyField<uint8_t>(verifier, VT_ARRAY_INDICES_TYPE) && VerifyOffset(verifier, VT_ARRAY_INDICES) && - verifier.VerifyVector(array_indices()) && + VerifySparseIndexVector(verifier, array_indices(), array_indices_type()) && verifier.EndTable(); } DimensionMetadataT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; @@ -3013,6 +3355,30 @@ struct DimensionMetadata FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { static flatbuffers::Offset<DimensionMetadata> Pack(flatbuffers::FlatBufferBuilder &_fbb, const DimensionMetadataT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; +template<> inline const Int32Vector *DimensionMetadata::array_segments_as<Int32Vector>() const { + return array_segments_as_Int32Vector(); +} + +template<> inline const Uint16Vector *DimensionMetadata::array_segments_as<Uint16Vector>() const { + return array_segments_as_Uint16Vector(); +} + +template<> inline const Uint8Vector *DimensionMetadata::array_segments_as<Uint8Vector>() const { + return array_segments_as_Uint8Vector(); +} + +template<> inline const Int32Vector *DimensionMetadata::array_indices_as<Int32Vector>() const { + return array_indices_as_Int32Vector(); +} + +template<> inline const Uint16Vector *DimensionMetadata::array_indices_as<Uint16Vector>() const { + return array_indices_as_Uint16Vector(); +} + +template<> inline const Uint8Vector *DimensionMetadata::array_indices_as<Uint8Vector>() const { + return array_indices_as_Uint8Vector(); +} + struct DimensionMetadataBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; @@ -3022,10 +3388,16 @@ struct DimensionMetadataBuilder { void add_dense_size(int32_t dense_size) { fbb_.AddElement<int32_t>(DimensionMetadata::VT_DENSE_SIZE, dense_size, 0); } - void add_array_segments(flatbuffers::Offset<flatbuffers::Vector<int32_t>> array_segments) { + void add_array_segments_type(SparseIndexVector array_segments_type) { + fbb_.AddElement<uint8_t>(DimensionMetadata::VT_ARRAY_SEGMENTS_TYPE, static_cast<uint8_t>(array_segments_type), 0); + } + void add_array_segments(flatbuffers::Offset<void> array_segments) { fbb_.AddOffset(DimensionMetadata::VT_ARRAY_SEGMENTS, array_segments); } - void add_array_indices(flatbuffers::Offset<flatbuffers::Vector<int32_t>> array_indices) { + void add_array_indices_type(SparseIndexVector array_indices_type) { + fbb_.AddElement<uint8_t>(DimensionMetadata::VT_ARRAY_INDICES_TYPE, static_cast<uint8_t>(array_indices_type), 0); + } + void add_array_indices(flatbuffers::Offset<void> array_indices) { fbb_.AddOffset(DimensionMetadata::VT_ARRAY_INDICES, array_indices); } explicit DimensionMetadataBuilder(flatbuffers::FlatBufferBuilder &_fbb) @@ -3044,32 +3416,20 @@ inline flatbuffers::Offset<DimensionMetadata> CreateDimensionMetadata( flatbuffers::FlatBufferBuilder &_fbb, DimensionType format = DimensionType_DENSE, int32_t dense_size = 0, - flatbuffers::Offset<flatbuffers::Vector<int32_t>> array_segments = 0, - flatbuffers::Offset<flatbuffers::Vector<int32_t>> array_indices = 0) { + SparseIndexVector array_segments_type = SparseIndexVector_NONE, + flatbuffers::Offset<void> array_segments = 0, + SparseIndexVector array_indices_type = SparseIndexVector_NONE, + flatbuffers::Offset<void> array_indices = 0) { DimensionMetadataBuilder builder_(_fbb); builder_.add_array_indices(array_indices); builder_.add_array_segments(array_segments); builder_.add_dense_size(dense_size); + builder_.add_array_indices_type(array_indices_type); + builder_.add_array_segments_type(array_segments_type); builder_.add_format(format); return builder_.Finish(); } -inline flatbuffers::Offset<DimensionMetadata> CreateDimensionMetadataDirect( - flatbuffers::FlatBufferBuilder &_fbb, - DimensionType format = DimensionType_DENSE, - int32_t dense_size = 0, - const std::vector<int32_t> *array_segments = nullptr, - const std::vector<int32_t> *array_indices = nullptr) { - auto array_segments__ = array_segments ? _fbb.CreateVector<int32_t>(*array_segments) : 0; - auto array_indices__ = array_indices ? _fbb.CreateVector<int32_t>(*array_indices) : 0; - return tflite::CreateDimensionMetadata( - _fbb, - format, - dense_size, - array_segments__, - array_indices__); -} - flatbuffers::Offset<DimensionMetadata> CreateDimensionMetadata(flatbuffers::FlatBufferBuilder &_fbb, const DimensionMetadataT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct SparsityParametersT : public flatbuffers::NativeTable { @@ -9896,6 +10256,7 @@ inline flatbuffers::Offset<Buffer> CreateBuffer( inline flatbuffers::Offset<Buffer> CreateBufferDirect( flatbuffers::FlatBufferBuilder &_fbb, const std::vector<uint8_t> *data = nullptr) { + if (data) { _fbb.ForceVectorAlignment(data->size(), sizeof(uint8_t), 16); } auto data__ = data ? _fbb.CreateVector<uint8_t>(*data) : 0; return tflite::CreateBuffer( _fbb, @@ -10157,6 +10518,7 @@ inline flatbuffers::Offset<CustomQuantization> CreateCustomQuantization(flatbuff (void)_rehasher; (void)_o; struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const CustomQuantizationT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + _fbb.ForceVectorAlignment(_o->custom.size(), sizeof(uint8_t), 16); auto _custom = _o->custom.size() ? _fbb.CreateVector(_o->custom) : 0; return tflite::CreateCustomQuantization( _fbb, @@ -10207,6 +10569,84 @@ inline flatbuffers::Offset<QuantizationParameters> CreateQuantizationParameters( _quantized_dimension); } +inline Int32VectorT *Int32Vector::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new Int32VectorT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void Int32Vector::UnPackTo(Int32VectorT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = values(); if (_e) { _o->values.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->values[_i] = _e->Get(_i); } } }; +} + +inline flatbuffers::Offset<Int32Vector> Int32Vector::Pack(flatbuffers::FlatBufferBuilder &_fbb, const Int32VectorT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateInt32Vector(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset<Int32Vector> CreateInt32Vector(flatbuffers::FlatBufferBuilder &_fbb, const Int32VectorT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const Int32VectorT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _values = _o->values.size() ? _fbb.CreateVector(_o->values) : 0; + return tflite::CreateInt32Vector( + _fbb, + _values); +} + +inline Uint16VectorT *Uint16Vector::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new Uint16VectorT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void Uint16Vector::UnPackTo(Uint16VectorT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = values(); if (_e) { _o->values.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->values[_i] = _e->Get(_i); } } }; +} + +inline flatbuffers::Offset<Uint16Vector> Uint16Vector::Pack(flatbuffers::FlatBufferBuilder &_fbb, const Uint16VectorT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateUint16Vector(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset<Uint16Vector> CreateUint16Vector(flatbuffers::FlatBufferBuilder &_fbb, const Uint16VectorT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const Uint16VectorT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _values = _o->values.size() ? _fbb.CreateVector(_o->values) : 0; + return tflite::CreateUint16Vector( + _fbb, + _values); +} + +inline Uint8VectorT *Uint8Vector::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new Uint8VectorT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void Uint8Vector::UnPackTo(Uint8VectorT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = values(); if (_e) { _o->values.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->values[_i] = _e->Get(_i); } } }; +} + +inline flatbuffers::Offset<Uint8Vector> Uint8Vector::Pack(flatbuffers::FlatBufferBuilder &_fbb, const Uint8VectorT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateUint8Vector(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset<Uint8Vector> CreateUint8Vector(flatbuffers::FlatBufferBuilder &_fbb, const Uint8VectorT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const Uint8VectorT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _values = _o->values.size() ? _fbb.CreateVector(_o->values) : 0; + return tflite::CreateUint8Vector( + _fbb, + _values); +} + inline DimensionMetadataT *DimensionMetadata::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new DimensionMetadataT(); UnPackTo(_o, _resolver); @@ -10218,8 +10658,10 @@ inline void DimensionMetadata::UnPackTo(DimensionMetadataT *_o, const flatbuffer (void)_resolver; { auto _e = format(); _o->format = _e; }; { auto _e = dense_size(); _o->dense_size = _e; }; - { auto _e = array_segments(); if (_e) { _o->array_segments.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->array_segments[_i] = _e->Get(_i); } } }; - { auto _e = array_indices(); if (_e) { _o->array_indices.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->array_indices[_i] = _e->Get(_i); } } }; + { auto _e = array_segments_type(); _o->array_segments.type = _e; }; + { auto _e = array_segments(); if (_e) _o->array_segments.value = SparseIndexVectorUnion::UnPack(_e, array_segments_type(), _resolver); }; + { auto _e = array_indices_type(); _o->array_indices.type = _e; }; + { auto _e = array_indices(); if (_e) _o->array_indices.value = SparseIndexVectorUnion::UnPack(_e, array_indices_type(), _resolver); }; } inline flatbuffers::Offset<DimensionMetadata> DimensionMetadata::Pack(flatbuffers::FlatBufferBuilder &_fbb, const DimensionMetadataT* _o, const flatbuffers::rehasher_function_t *_rehasher) { @@ -10232,13 +10674,17 @@ inline flatbuffers::Offset<DimensionMetadata> CreateDimensionMetadata(flatbuffer struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const DimensionMetadataT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _format = _o->format; auto _dense_size = _o->dense_size; - auto _array_segments = _o->array_segments.size() ? _fbb.CreateVector(_o->array_segments) : 0; - auto _array_indices = _o->array_indices.size() ? _fbb.CreateVector(_o->array_indices) : 0; + auto _array_segments_type = _o->array_segments.type; + auto _array_segments = _o->array_segments.Pack(_fbb); + auto _array_indices_type = _o->array_indices.type; + auto _array_indices = _o->array_indices.Pack(_fbb); return tflite::CreateDimensionMetadata( _fbb, _format, _dense_size, + _array_segments_type, _array_segments, + _array_indices_type, _array_indices); } @@ -13082,6 +13528,7 @@ inline flatbuffers::Offset<Buffer> CreateBuffer(flatbuffers::FlatBufferBuilder & (void)_rehasher; (void)_o; struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const BufferT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + _fbb.ForceVectorAlignment(_o->data.size(), sizeof(uint8_t), 16); auto _data = _o->data.size() ? _fbb.CreateVector(_o->data) : 0; return tflite::CreateBuffer( _fbb, @@ -13230,6 +13677,117 @@ inline void QuantizationDetailsUnion::Reset() { type = QuantizationDetails_NONE; } +inline bool VerifySparseIndexVector(flatbuffers::Verifier &verifier, const void *obj, SparseIndexVector type) { + switch (type) { + case SparseIndexVector_NONE: { + return true; + } + case SparseIndexVector_Int32Vector: { + auto ptr = reinterpret_cast<const Int32Vector *>(obj); + return verifier.VerifyTable(ptr); + } + case SparseIndexVector_Uint16Vector: { + auto ptr = reinterpret_cast<const Uint16Vector *>(obj); + return verifier.VerifyTable(ptr); + } + case SparseIndexVector_Uint8Vector: { + auto ptr = reinterpret_cast<const Uint8Vector *>(obj); + return verifier.VerifyTable(ptr); + } + default: return true; + } +} + +inline bool VerifySparseIndexVectorVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector<flatbuffers::Offset<void>> *values, const flatbuffers::Vector<uint8_t> *types) { + if (!values || !types) return !values && !types; + if (values->size() != types->size()) return false; + for (flatbuffers::uoffset_t i = 0; i < values->size(); ++i) { + if (!VerifySparseIndexVector( + verifier, values->Get(i), types->GetEnum<SparseIndexVector>(i))) { + return false; + } + } + return true; +} + +inline void *SparseIndexVectorUnion::UnPack(const void *obj, SparseIndexVector type, const flatbuffers::resolver_function_t *resolver) { + switch (type) { + case SparseIndexVector_Int32Vector: { + auto ptr = reinterpret_cast<const Int32Vector *>(obj); + return ptr->UnPack(resolver); + } + case SparseIndexVector_Uint16Vector: { + auto ptr = reinterpret_cast<const Uint16Vector *>(obj); + return ptr->UnPack(resolver); + } + case SparseIndexVector_Uint8Vector: { + auto ptr = reinterpret_cast<const Uint8Vector *>(obj); + return ptr->UnPack(resolver); + } + default: return nullptr; + } +} + +inline flatbuffers::Offset<void> SparseIndexVectorUnion::Pack(flatbuffers::FlatBufferBuilder &_fbb, const flatbuffers::rehasher_function_t *_rehasher) const { + switch (type) { + case SparseIndexVector_Int32Vector: { + auto ptr = reinterpret_cast<const Int32VectorT *>(value); + return CreateInt32Vector(_fbb, ptr, _rehasher).Union(); + } + case SparseIndexVector_Uint16Vector: { + auto ptr = reinterpret_cast<const Uint16VectorT *>(value); + return CreateUint16Vector(_fbb, ptr, _rehasher).Union(); + } + case SparseIndexVector_Uint8Vector: { + auto ptr = reinterpret_cast<const Uint8VectorT *>(value); + return CreateUint8Vector(_fbb, ptr, _rehasher).Union(); + } + default: return 0; + } +} + +inline SparseIndexVectorUnion::SparseIndexVectorUnion(const SparseIndexVectorUnion &u) FLATBUFFERS_NOEXCEPT : type(u.type), value(nullptr) { + switch (type) { + case SparseIndexVector_Int32Vector: { + value = new Int32VectorT(*reinterpret_cast<Int32VectorT *>(u.value)); + break; + } + case SparseIndexVector_Uint16Vector: { + value = new Uint16VectorT(*reinterpret_cast<Uint16VectorT *>(u.value)); + break; + } + case SparseIndexVector_Uint8Vector: { + value = new Uint8VectorT(*reinterpret_cast<Uint8VectorT *>(u.value)); + break; + } + default: + break; + } +} + +inline void SparseIndexVectorUnion::Reset() { + switch (type) { + case SparseIndexVector_Int32Vector: { + auto ptr = reinterpret_cast<Int32VectorT *>(value); + delete ptr; + break; + } + case SparseIndexVector_Uint16Vector: { + auto ptr = reinterpret_cast<Uint16VectorT *>(value); + delete ptr; + break; + } + case SparseIndexVector_Uint8Vector: { + auto ptr = reinterpret_cast<Uint8VectorT *>(value); + delete ptr; + break; + } + default: break; + } + value = nullptr; + type = SparseIndexVector_NONE; +} + inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type) { switch (type) { case BuiltinOptions_NONE: { diff --git a/tensorflow/lite/testdata/sparse_tensor.bin b/tensorflow/lite/testdata/sparse_tensor.bin index 497ce68a3ac..c035e02441d 100644 Binary files a/tensorflow/lite/testdata/sparse_tensor.bin and b/tensorflow/lite/testdata/sparse_tensor.bin differ diff --git a/tensorflow/lite/testdata/sparse_tensor.json b/tensorflow/lite/testdata/sparse_tensor.json index ce627e2bb2d..d23c0d0a64b 100644 --- a/tensorflow/lite/testdata/sparse_tensor.json +++ b/tensorflow/lite/testdata/sparse_tensor.json @@ -25,8 +25,10 @@ }, { "format": "SPARSE_CSR", - "array_segments": [0, 2, 3], - "array_indices": [0, 1, 1] + "array_segments_type": "Uint8Vector", + "array_segments": {"values": [0, 2, 3]}, + "array_indices_type": "Uint8Vector", + "array_indices": {"values": [0, 1, 1]} }, { "format": "DENSE", diff --git a/tensorflow/lite/tools/verifier.cc b/tensorflow/lite/tools/verifier.cc index d9b737ba77e..f5b369ea501 100644 --- a/tensorflow/lite/tools/verifier.cc +++ b/tensorflow/lite/tools/verifier.cc @@ -113,6 +113,65 @@ bool VerifyStringTensorBuffer(const Tensor& tensor, const Buffer& buffer, return true; } +int GetSizeOfSegments(const DimensionMetadata* dim_metadata) { + switch (dim_metadata->array_segments_type()) { + case SparseIndexVector_Int32Vector: + return dim_metadata->array_segments_as_Int32Vector()->values()->size(); + case SparseIndexVector_Uint16Vector: + return dim_metadata->array_segments_as_Uint16Vector()->values()->size(); + case SparseIndexVector_Uint8Vector: + return dim_metadata->array_segments_as_Uint8Vector()->values()->size(); + default: + return -1; + } +} + +int GetValueOfSegmentsAt(const DimensionMetadata* dim_metadata, const int i) { + switch (dim_metadata->array_segments_type()) { + case SparseIndexVector_Int32Vector: + return static_cast<int>( + dim_metadata->array_segments_as_Int32Vector()->values()->Get(i)); + case SparseIndexVector_Uint16Vector: + return static_cast<int>( + dim_metadata->array_segments_as_Uint16Vector()->values()->Get(i)); + case SparseIndexVector_Uint8Vector: + return static_cast<int>( + dim_metadata->array_segments_as_Uint8Vector()->values()->Get(i)); + default: + return -1; + } +} + +int GetSizeOfIndices(const DimensionMetadata* dim_metadata) { + switch (dim_metadata->array_indices_type()) { + case SparseIndexVector_Int32Vector: + return dim_metadata->array_indices_as_Int32Vector()->values()->size(); + case SparseIndexVector_Uint16Vector: + return dim_metadata->array_indices_as_Uint16Vector()->values()->size(); + case SparseIndexVector_Uint8Vector: + return dim_metadata->array_indices_as_Uint8Vector()->values()->size(); + default: + return -1; + } +} + +int GetValueOfIndicesAt(const DimensionMetadata* dim_metadata, const int i) { + switch (dim_metadata->array_indices_type()) { + case SparseIndexVector_Int32Vector: + return static_cast<int>( + dim_metadata->array_indices_as_Int32Vector()->values()->Get(i)); + case SparseIndexVector_Uint16Vector: + return static_cast<int>( + dim_metadata->array_indices_as_Uint16Vector()->values()->Get(i)); + case SparseIndexVector_Uint8Vector: + return static_cast<int>( + dim_metadata->array_indices_as_Uint8Vector()->values()->Get(i)); + default: + return -1; + } + return -1; +} + // The sparsity parameter defines a tree structure to map each non-zero element // stored in the flattened buffer back to its index in the conceptual dense // tensor. @@ -139,31 +198,36 @@ absl::optional<uint64_t> VerifyAndCountElements( return absl::nullopt; } - for (int j = 0; j < array_segments->size() - 1; j++) { - if (array_segments->Get(j) < 0 || array_segments->Get(j + 1) < 0 || - array_segments->Get(j) > array_segments->Get(j + 1)) { + int array_segments_size = GetSizeOfSegments(dim_metadata); + int array_indices_size = GetSizeOfIndices(dim_metadata); + + for (int j = 0; j < array_segments_size - 1; j++) { + if (GetValueOfSegmentsAt(dim_metadata, j) < 0 || + GetValueOfSegmentsAt(dim_metadata, j + 1) < 0 || + GetValueOfSegmentsAt(dim_metadata, j) > + GetValueOfSegmentsAt(dim_metadata, j + 1)) { return absl::nullopt; } } - if (num_elements != array_segments->size() - 1) { + if (num_elements != array_segments_size - 1) { return absl::nullopt; } - if (array_indices->size() != - array_segments->Get(array_segments->size() - 1)) { + if (array_indices_size != + GetValueOfSegmentsAt(dim_metadata, array_segments_size - 1)) { return absl::nullopt; } - for (int j = 0; j < array_indices->size(); j++) { - if (array_indices->Get(j) < 0 || - array_indices->Get(j) >= dim_sizes[original_dim]) { + for (int j = 0; j < array_indices_size; j++) { + if (GetValueOfIndicesAt(dim_metadata, j) < 0 || + GetValueOfIndicesAt(dim_metadata, j) >= dim_sizes[original_dim]) { return absl::nullopt; } } // Need to reset num_elements when seeing a sparse dimension. - num_elements = array_indices->size(); + num_elements = array_indices_size; } } diff --git a/tensorflow/lite/tools/verifier_test.cc b/tensorflow/lite/tools/verifier_test.cc index 355ee6640c6..1e13fda7c33 100644 --- a/tensorflow/lite/tools/verifier_test.cc +++ b/tensorflow/lite/tools/verifier_test.cc @@ -613,7 +613,8 @@ TEST(VerifyModel, InvalidSparseTensorIndexOutOfBound) { scoped_model.reset(model->GetModel()->UnPack()); auto* tensor = scoped_model->subgraphs[0]->tensors[0].get(); - tensor->sparsity->dim_metadata[1]->array_indices[1] = 5; + tensor->sparsity->dim_metadata[1]->array_indices.AsUint8Vector()->values[1] = + 5; flatbuffers::FlatBufferBuilder builder; auto model_ = Model::Pack(builder, scoped_model.get()); @@ -693,8 +694,10 @@ TEST(VerifyModel, ValidSparseTensorBCSC) { tensor->sparsity->dim_metadata[0]->dense_size = 2; tensor->sparsity->dim_metadata[1]->format = DimensionType_SPARSE_CSR; - tensor->sparsity->dim_metadata[1]->array_segments = {0, 1, 3}; - tensor->sparsity->dim_metadata[1]->array_indices = {0, 0, 1}; + tensor->sparsity->dim_metadata[1]->array_segments.AsUint8Vector()->values = { + 0, 1, 3}; + tensor->sparsity->dim_metadata[1]->array_indices.AsUint8Vector()->values = { + 0, 0, 1}; tensor->sparsity->dim_metadata[2]->format = DimensionType_DENSE; tensor->sparsity->dim_metadata[2]->dense_size = 2;