Add support for unknown dimensions to TFLite using MLIR converter.

PiperOrigin-RevId: 292563455
Change-Id: Ib5700cfe6faee177027329e32089abb3bcc9adaf
This commit is contained in:
Nupur Garg 2020-01-31 09:53:26 -08:00 committed by TensorFlower Gardener
parent 4e20c32249
commit 55912083e2
17 changed files with 284 additions and 40 deletions

View File

@ -610,6 +610,7 @@ Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
};
std::vector<int32_t> shape;
std::vector<int32_t> shape_signature;
if (type.hasStaticShape()) {
llvm::ArrayRef<int64_t> shape_ref = type.getShape();
if (mlir::failed(check_shape(shape_ref))) return llvm::None;
@ -627,7 +628,17 @@ Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
shape = std::vector<int32_t>(shape_ref.begin(), shape_ref.end());
}
} else if (type.hasRank()) {
llvm::ArrayRef<int64_t> shape_ref = type.getShape();
if (mlir::failed(check_shape(shape_ref))) return llvm::None;
shape.reserve(shape_ref.size());
for (auto& dim : shape_ref) {
shape.push_back(dim == -1 ? 1 : dim);
}
shape_signature = std::vector<int32_t>(shape_ref.begin(), shape_ref.end());
}
Type element_type = type.getElementType();
tflite::TensorType tflite_element_type =
GetTFLiteType(type.getElementType()).ValueOrDie();
@ -664,10 +675,19 @@ Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
break;
}
}
return tflite::CreateTensor(
builder_, builder_.CreateVector(shape), tflite_element_type,
(is_variable ? 0 : buffer_idx), builder_.CreateString(name), q_params,
/*is_variable=*/is_variable);
if (shape_signature.empty()) {
return tflite::CreateTensor(
builder_, builder_.CreateVector(shape), tflite_element_type,
(is_variable ? 0 : buffer_idx), builder_.CreateString(name), q_params,
/*is_variable=*/is_variable);
} else {
return tflite::CreateTensor(
builder_, builder_.CreateVector(shape), tflite_element_type,
(is_variable ? 0 : buffer_idx), builder_.CreateString(name), q_params,
/*is_variable=*/is_variable, /*sparsity=*/0,
/*shape_signature=*/builder_.CreateVector(shape_signature));
}
}
BufferOffset<tflite::Operator> Translator::BuildIfOperator(

View File

@ -140,6 +140,11 @@ void TfLiteTensorFree(TfLiteTensor* t) {
if (t->dims) TfLiteIntArrayFree(t->dims);
t->dims = NULL;
if (t->dims_signature) {
TfLiteIntArrayFree((TfLiteIntArray *) t->dims_signature);
}
t->dims_signature = NULL;
TfLiteQuantizationFree(&t->quantization);
TfLiteSparsityFree(t->sparsity);
t->sparsity = NULL;

View File

@ -391,6 +391,12 @@ typedef struct TfLiteTensor {
// This is optional. The field is NULL if a tensor is dense.
// WARNING: This is an experimental interface that is subject to change.
TfLiteSparsity* sparsity;
// Optional. Encodes shapes with unknown dimensions with -1. This field is
// only populated when unknown dimensions exist in a read-write tensor (i.e.
// an input or output tensor). (e.g. `dims` contains [1, 1, 1, 3] and
// `dims_signature` contains [1, -1, -1, 3]).
const TfLiteIntArray* dims_signature;
} TfLiteTensor;
#ifndef TF_LITE_STATIC_MEMORY

View File

@ -95,6 +95,7 @@ TEST(Quantization, TestQuantizationFree) {
// Set these values, otherwise TfLiteTensorFree has uninitialized values.
t.allocation_type = kTfLiteArenaRw;
t.dims = nullptr;
t.dims_signature = nullptr;
t.quantization.type = kTfLiteAffineQuantization;
t.sparsity = nullptr;
auto* params = reinterpret_cast<TfLiteAffineQuantization*>(
@ -110,6 +111,7 @@ TEST(Sparsity, TestSparsityFree) {
// Set these values, otherwise TfLiteTensorFree has uninitialized values.
t.allocation_type = kTfLiteArenaRw;
t.dims = nullptr;
t.dims_signature = nullptr;
// A dummy CSR sparse matrix.
t.sparsity = static_cast<TfLiteSparsity*>(malloc(sizeof(TfLiteSparsity)));

View File

@ -1074,7 +1074,8 @@ TfLiteStatus Subgraph::SetTensorParametersReadOnly(
// to Interpreter.
TfLiteStatus Subgraph::SetTensorParametersReadWrite(
int tensor_index, TfLiteType type, const char* name, const size_t rank,
const int* dims, TfLiteQuantization quantization, bool is_variable) {
const int* dims, TfLiteQuantization quantization, bool is_variable,
const size_t rank_dims_signature, const int* dims_signature) {
// Ensure quantization cleanup on failure.
ScopedTfLiteQuantization scoped_quantization(&quantization);
if (state_ == kStateInvokableAndImmutable) {
@ -1114,6 +1115,8 @@ TfLiteStatus Subgraph::SetTensorParametersReadWrite(
// TODO(suharshs): Update TfLiteTensorReset to include the new quantization
// if there are other required callers.
tensor.quantization = *scoped_quantization.release();
tensor.dims_signature =
ConvertArrayToTfLiteIntArray(rank_dims_signature, dims_signature);
return kTfLiteOk;
}

View File

@ -114,15 +114,17 @@ class Subgraph {
inline TfLiteStatus SetTensorParametersReadWrite(
int tensor_index, TfLiteType type, const char* name,
const std::vector<int>& dims, TfLiteQuantization quantization,
bool is_variable = false) {
bool is_variable = false, const size_t rank_dims_signature = 0,
const int* dims_signature = nullptr) {
return SetTensorParametersReadWrite(tensor_index, type, name, dims.size(),
dims.data(), quantization, is_variable);
dims.data(), quantization, is_variable,
rank_dims_signature, dims_signature);
}
TfLiteStatus SetTensorParametersReadWrite(int tensor_index, TfLiteType type,
const char* name, const size_t rank,
const int* dims,
TfLiteQuantization quantization,
bool is_variable = false);
TfLiteStatus SetTensorParametersReadWrite(
int tensor_index, TfLiteType type, const char* name, const size_t rank,
const int* dims, TfLiteQuantization quantization,
bool is_variable = false, const size_t rank_dims_signature = 0,
const int* dims_signature = nullptr);
// WARNING: Experimental interface, subject to change
// Overrides execution plan. This bounds checks indices sent in.

View File

@ -563,6 +563,13 @@ TfLiteStatus InterpreterBuilder::ParseTensors(
status = kTfLiteError;
}
size_t dims_signature_rank = 0;
const int* dims_signature_data = nullptr;
if (tensor->shape_signature()) {
dims_signature_rank = tensor->shape_signature()->Length();
dims_signature_data = tensor->shape_signature()->data();
}
bool is_variable = tensor->is_variable();
if (buffer_ptr) {
if (is_variable) {
@ -590,9 +597,9 @@ TfLiteStatus InterpreterBuilder::ParseTensors(
status = kTfLiteError;
}
} else {
if (subgraph->SetTensorParametersReadWrite(i, type, get_name(tensor),
dims, quantization,
is_variable) != kTfLiteOk) {
if (subgraph->SetTensorParametersReadWrite(
i, type, get_name(tensor), dims, quantization, is_variable,
dims_signature_rank, dims_signature_data) != kTfLiteOk) {
error_reporter_->Report("Tensor %d is invalidly specified in schema.\n",
i);
status = kTfLiteError;

View File

@ -35,6 +35,7 @@ from tensorflow.lite.python import wrap_toco
from tensorflow.lite.toco import model_flags_pb2 as _model_flags_pb2
from tensorflow.lite.toco import toco_flags_pb2 as _toco_flags_pb2
from tensorflow.lite.toco import types_pb2 as _types_pb2
from tensorflow.python.framework import tensor_shape
from tensorflow.python.platform import resource_loader as _resource_loader
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export as _tf_export
@ -384,7 +385,16 @@ def build_toco_convert_protos(input_tensors,
shape = input_tensor.shape
else:
shape = input_shapes[idx]
input_array.shape.dims.extend(list(map(int, shape)))
# Create shapes with -1 for unknown dimensions.
dims = []
for dim in shape:
if (dim is None or
(isinstance(dim, tensor_shape.Dimension) and dim.value is None)):
dims.append(-1)
else:
dims.append(int(dim))
input_array.shape.dims.extend(dims)
for output_tensor in output_tensors:
model.output_arrays.append(util.get_tensor_name(output_tensor))

View File

@ -320,6 +320,7 @@ class Interpreter(object):
tensor_index = int(tensor_index)
tensor_name = self._interpreter.TensorName(tensor_index)
tensor_size = self._interpreter.TensorSize(tensor_index)
tensor_size_signature = self._interpreter.TensorSizeSignature(tensor_index)
tensor_type = self._interpreter.TensorType(tensor_index)
tensor_quantization = self._interpreter.TensorQuantization(tensor_index)
tensor_quantization_params = self._interpreter.TensorQuantizationParameters(
@ -332,6 +333,7 @@ class Interpreter(object):
'name': tensor_name,
'index': tensor_index,
'shape': tensor_size,
'shape_signature': tensor_size_signature,
'dtype': tensor_type,
'quantization': tensor_quantization,
'quantization_parameters': {

View File

@ -301,6 +301,23 @@ PyObject* InterpreterWrapper::TensorSize(int i) const {
return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array));
}
PyObject* InterpreterWrapper::TensorSizeSignature(int i) const {
TFLITE_PY_ENSURE_VALID_INTERPRETER();
TFLITE_PY_TENSOR_BOUNDS_CHECK(i);
const TfLiteTensor* tensor = interpreter_->tensor(i);
const int32_t* size_signature_data = nullptr;
int32_t size_signature_size = 0;
if (tensor->dims_signature != nullptr) {
size_signature_data = tensor->dims_signature->data;
size_signature_size = tensor->dims_signature->size;
}
PyObject* np_array =
PyArrayFromIntVector(size_signature_data, size_signature_size);
return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array));
}
PyObject* InterpreterWrapper::TensorQuantization(int i) const {
TFLITE_PY_ENSURE_VALID_INTERPRETER();
TFLITE_PY_TENSOR_BOUNDS_CHECK(i);

View File

@ -69,6 +69,7 @@ class InterpreterWrapper {
std::string TensorName(int i) const;
PyObject* TensorType(int i) const;
PyObject* TensorSize(int i) const;
PyObject* TensorSizeSignature(int i) const;
// Deprecated in favor of TensorQuantizationScales, below.
PyObject* TensorQuantization(int i) const;
PyObject* TensorQuantizationParameters(int i) const;

View File

@ -261,6 +261,16 @@ class TFLiteConverterBase(object):
self.representative_dataset.input_gen, inference_input_type,
inference_output_type, allow_float, enable_mlir_quantizer)
def _is_unknown_shapes_allowed(self):
# TODO(b/128319310): Investigate which quantization methods work.
if self._any_optimization_enabled():
return False
# Unknown dimensions are only allowed with the new converter.
if not self.experimental_new_converter:
return False
return True
def _get_base_converter_args(self):
"""Returns the base converter args.
@ -456,19 +466,21 @@ class TFLiteConverterV2(TFLiteConverterBase):
config=self._grappler_config(),
graph=frozen_func.graph)
# Checks dimensions in input tensor.
for tensor in input_tensors:
# Note that shape_list might be empty for scalar shapes.
shape_list = tensor.shape.as_list()
if None in shape_list[1:]:
raise ValueError(
"None is only supported in the 1st dimension. Tensor '{0}' has "
"invalid shape '{1}'.".format(_get_tensor_name(tensor), shape_list))
elif shape_list and shape_list[0] is None:
# Set the batch size to 1 if undefined.
shape = tensor.shape.as_list()
shape[0] = 1
tensor.set_shape(shape)
if not self._is_unknown_shapes_allowed():
# Checks dimensions in input tensor.
for tensor in input_tensors:
# Note that shape_list might be empty for scalar shapes.
shape_list = tensor.shape.as_list()
if None in shape_list[1:]:
raise ValueError(
"None is only supported in the 1st dimension. Tensor '{0}' has "
"invalid shape '{1}'.".format(
_get_tensor_name(tensor), shape_list))
elif shape_list and shape_list[0] is None:
# Set the batch size to 1 if undefined.
shape = tensor.shape.as_list()
shape[0] = 1
tensor.set_shape(shape)
self._validate_quantization()
self._validate_representative_dataset()
@ -942,7 +954,7 @@ class TFLiteConverter(TFLiteConverterBase):
None value for dimension in input_tensor.
"""
# Checks dimensions in input tensor.
if self._has_valid_tensors():
if not self._is_unknown_shapes_allowed() and self._has_valid_tensors():
for tensor in self._input_tensors:
shape = tensor.shape
if not shape:
@ -1115,6 +1127,20 @@ class TFLiteConverter(TFLiteConverterBase):
shape[0] = batch_size
tensor.set_shape(shape)
def _is_unknown_shapes_allowed(self):
if not super(TFLiteConverter, self)._is_unknown_shapes_allowed():
return False
# `conversion_summary_dir` calls TOCO. Unknown shapes are only supported by
# the MLIR converter.
if self.conversion_summary_dir:
logging.warning(
"`conversion_summary_dir` does not work with unknown shapes. "
"Graphs with unknown shapes might be different than when this flag "
"is disabled.")
return False
return True
@_tf_export(v1=["lite.TocoConverter"])
class TocoConverter(object):

View File

@ -318,9 +318,11 @@ class FromSessionTest(TestModels, parameterized.TestCase):
out_tensor = in_tensor + in_tensor
sess = session.Session()
# Test None as shape.
# Test None as shape when dynamic shapes are disabled. Run with TOCO in
# order to invoke shape checking code.
converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
[out_tensor])
converter.experimental_new_converter = False
with self.assertRaises(ValueError) as error:
converter.convert()
self.assertEqual('Provide an input shape for input array \'Placeholder\'.',
@ -375,9 +377,11 @@ class FromSessionTest(TestModels, parameterized.TestCase):
out_tensor = in_tensor + in_tensor
sess = session.Session()
# Test invalid shape. None after 1st dimension.
# Test invalid shape. None after 1st dimension. Run with TOCO in order to
# invoke shape checking code.
converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
[out_tensor])
converter.experimental_new_converter = False
with self.assertRaises(ValueError) as error:
converter.convert()
self.assertEqual(
@ -385,6 +389,44 @@ class FromSessionTest(TestModels, parameterized.TestCase):
'\'Placeholder\' has invalid shape \'[1, None, 16, 3]\'.',
str(error.exception))
def testSizeNone(self):
with ops.Graph().as_default():
in_tensor = array_ops.placeholder(
shape=[1, None, 16, 3], dtype=dtypes.float32)
out_tensor = in_tensor + in_tensor
sess = session.Session()
# Test None after 1st dimension.
converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
[out_tensor])
converter.experimental_new_converter = True
tflite_model = converter.convert()
# Check values from converted model.
interpreter = Interpreter(model_content=tflite_model)
input_details = interpreter.get_input_details()
self.assertLen(input_details, 1)
self.assertEqual('Placeholder', input_details[0]['name'])
self.assertEqual(np.float32, input_details[0]['dtype'])
self.assertTrue(([1, 1, 16, 3] == input_details[0]['shape']).all())
self.assertTrue(([1, -1, 16,
3] == input_details[0]['shape_signature']).all())
self.assertEqual((0., 0.), input_details[0]['quantization'])
# Resize tensor and invoke.
interpreter.resize_tensor_input(0, [1, 16, 16, 3])
interpreter.allocate_tensors()
interpreter.invoke()
input_details = interpreter.get_input_details()
self.assertLen(input_details, 1)
self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
self.assertTrue(([1, -1, 16,
3] == input_details[0]['shape_signature']).all())
output_details = interpreter.get_output_details()
self.assertFalse(output_details[0]['shape_signature'])
def testBatchSizeValid(self):
with ops.Graph().as_default():
in_tensor = array_ops.placeholder(

View File

@ -54,12 +54,28 @@ from tensorflow.python.training.tracking import tracking
class TestModels(test_util.TensorFlowTestCase, parameterized.TestCase):
def _evaluateTFLiteModel(self, tflite_model, input_data):
"""Evaluates the model on the `input_data`."""
def _evaluateTFLiteModel(self, tflite_model, input_data, input_shapes=None):
"""Evaluates the model on the `input_data`.
Args:
tflite_model: TensorFlow Lite model.
input_data: List of EagerTensor const ops containing the input data for
each input tensor.
input_shapes: List of tuples representing the `shape_signature` and the
new shape of each input tensor that has unknown dimensions.
Returns:
[np.ndarray]
"""
interpreter = Interpreter(model_content=tflite_model)
input_details = interpreter.get_input_details()
if input_shapes:
for idx, (shape_signature, final_shape) in enumerate(input_shapes):
self.assertTrue(
(input_details[idx]['shape_signature'] == shape_signature).all())
interpreter.resize_tensor_input(idx, final_shape)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
for input_tensor, tensor_data in zip(input_details, input_data):
@ -795,5 +811,62 @@ class GrapplerTest(TestModels):
actual_value = self._evaluateTFLiteModel(hybrid_tflite_model, [input_data])
np.testing.assert_almost_equal(expected_value.numpy(), actual_value[0])
class UnknownShapes(TestModels):
@test_util.run_v2_only
def testMatMul(self):
input_data = constant_op.constant(
np.array(np.random.random_sample((10, 4)), dtype=np.float32))
@def_function.function(input_signature=[
tensor_spec.TensorSpec(shape=[None, 4], dtype=dtypes.float32)
])
def model(in_tensor):
shape = array_ops.shape_v2(in_tensor)
fill = array_ops.transpose_v2(array_ops.fill(shape, 1.))
return math_ops.matmul(fill, in_tensor)
concrete_func = model.get_concrete_function()
converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func])
converter.experimental_new_converter = True
tflite_model = converter.convert()
# Check values from converted model.
expected_value = concrete_func(input_data)
actual_value = self._evaluateTFLiteModel(
tflite_model, [input_data], input_shapes=[([-1, 4], [10, 4])])
np.testing.assert_almost_equal(
expected_value.numpy(), actual_value[0], decimal=6)
def testBatchMatMul(self):
self.skipTest('BatchMatMulV2 ranked tensor check fails.')
input_data_1 = constant_op.constant(
np.array(np.random.random_sample((1, 256, 256)), dtype=np.float32))
input_data_2 = constant_op.constant(
np.array(np.random.random_sample((1, 2, 256)), dtype=np.float32))
@def_function.function(input_signature=[
tensor_spec.TensorSpec(shape=[1, 256, 256], dtype=dtypes.float32),
tensor_spec.TensorSpec(shape=[1, None, 256], dtype=dtypes.float32)
])
def model(in_tensor_1, in_tensor_2):
return math_ops.matmul(in_tensor_1, in_tensor_2)
concrete_func = model.get_concrete_function()
converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func])
converter.experimental_new_converter = True
tflite_model = converter.convert()
# Check values from converted model.
expected_value = concrete_func(input_data_1, input_data_2)
actual_value = self._evaluateTFLiteModel(
tflite_model, [input_data_1, input_data_2],
input_shapes={1: [1, 2, 256]})
np.testing.assert_almost_equal(expected_value.numpy(), actual_value[0])
if __name__ == '__main__':
test.main()

View File

@ -178,6 +178,10 @@ table Tensor {
// Parameters to encode a sparse tensor. See the example in
// tensorflow/lite/testdata/sparse_tensor.json.
sparsity:SparsityParameters; // Optional.
// Encodes `shape` with unknown dimensions. Unknown dimensions are
// represented with -1.
shape_signature:[int]; // Optional.
}
// A list of builtin operators. Builtin operators are slightly faster than custom

View File

@ -3175,6 +3175,7 @@ struct TensorT : public flatbuffers::NativeTable {
std::unique_ptr<QuantizationParametersT> quantization;
bool is_variable;
std::unique_ptr<SparsityParametersT> sparsity;
std::vector<int32_t> shape_signature;
TensorT()
: type(TensorType_FLOAT32),
buffer(0),
@ -3191,7 +3192,8 @@ struct Tensor FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
VT_NAME = 10,
VT_QUANTIZATION = 12,
VT_IS_VARIABLE = 14,
VT_SPARSITY = 16
VT_SPARSITY = 16,
VT_SHAPE_SIGNATURE = 18
};
const flatbuffers::Vector<int32_t> *shape() const {
return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_SHAPE);
@ -3214,6 +3216,9 @@ struct Tensor FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
const SparsityParameters *sparsity() const {
return GetPointer<const SparsityParameters *>(VT_SPARSITY);
}
const flatbuffers::Vector<int32_t> *shape_signature() const {
return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_SHAPE_SIGNATURE);
}
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyOffset(verifier, VT_SHAPE) &&
@ -3227,6 +3232,8 @@ struct Tensor FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
VerifyField<uint8_t>(verifier, VT_IS_VARIABLE) &&
VerifyOffset(verifier, VT_SPARSITY) &&
verifier.VerifyTable(sparsity()) &&
VerifyOffset(verifier, VT_SHAPE_SIGNATURE) &&
verifier.VerifyVector(shape_signature()) &&
verifier.EndTable();
}
TensorT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@ -3258,6 +3265,9 @@ struct TensorBuilder {
void add_sparsity(flatbuffers::Offset<SparsityParameters> sparsity) {
fbb_.AddOffset(Tensor::VT_SPARSITY, sparsity);
}
void add_shape_signature(flatbuffers::Offset<flatbuffers::Vector<int32_t>> shape_signature) {
fbb_.AddOffset(Tensor::VT_SHAPE_SIGNATURE, shape_signature);
}
explicit TensorBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
@ -3278,8 +3288,10 @@ inline flatbuffers::Offset<Tensor> CreateTensor(
flatbuffers::Offset<flatbuffers::String> name = 0,
flatbuffers::Offset<QuantizationParameters> quantization = 0,
bool is_variable = false,
flatbuffers::Offset<SparsityParameters> sparsity = 0) {
flatbuffers::Offset<SparsityParameters> sparsity = 0,
flatbuffers::Offset<flatbuffers::Vector<int32_t>> shape_signature = 0) {
TensorBuilder builder_(_fbb);
builder_.add_shape_signature(shape_signature);
builder_.add_sparsity(sparsity);
builder_.add_quantization(quantization);
builder_.add_name(name);
@ -3298,9 +3310,11 @@ inline flatbuffers::Offset<Tensor> CreateTensorDirect(
const char *name = nullptr,
flatbuffers::Offset<QuantizationParameters> quantization = 0,
bool is_variable = false,
flatbuffers::Offset<SparsityParameters> sparsity = 0) {
flatbuffers::Offset<SparsityParameters> sparsity = 0,
const std::vector<int32_t> *shape_signature = nullptr) {
auto shape__ = shape ? _fbb.CreateVector<int32_t>(*shape) : 0;
auto name__ = name ? _fbb.CreateString(name) : 0;
auto shape_signature__ = shape_signature ? _fbb.CreateVector<int32_t>(*shape_signature) : 0;
return tflite::CreateTensor(
_fbb,
shape__,
@ -3309,7 +3323,8 @@ inline flatbuffers::Offset<Tensor> CreateTensorDirect(
name__,
quantization,
is_variable,
sparsity);
sparsity,
shape_signature__);
}
flatbuffers::Offset<Tensor> CreateTensor(flatbuffers::FlatBufferBuilder &_fbb, const TensorT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
@ -10275,6 +10290,7 @@ inline void Tensor::UnPackTo(TensorT *_o, const flatbuffers::resolver_function_t
{ auto _e = quantization(); if (_e) _o->quantization = std::unique_ptr<QuantizationParametersT>(_e->UnPack(_resolver)); };
{ auto _e = is_variable(); _o->is_variable = _e; };
{ auto _e = sparsity(); if (_e) _o->sparsity = std::unique_ptr<SparsityParametersT>(_e->UnPack(_resolver)); };
{ auto _e = shape_signature(); if (_e) { _o->shape_signature.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->shape_signature[_i] = _e->Get(_i); } } };
}
inline flatbuffers::Offset<Tensor> Tensor::Pack(flatbuffers::FlatBufferBuilder &_fbb, const TensorT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
@ -10292,6 +10308,7 @@ inline flatbuffers::Offset<Tensor> CreateTensor(flatbuffers::FlatBufferBuilder &
auto _quantization = _o->quantization ? CreateQuantizationParameters(_fbb, _o->quantization.get(), _rehasher) : 0;
auto _is_variable = _o->is_variable;
auto _sparsity = _o->sparsity ? CreateSparsityParameters(_fbb, _o->sparsity.get(), _rehasher) : 0;
auto _shape_signature = _o->shape_signature.size() ? _fbb.CreateVector(_o->shape_signature) : 0;
return tflite::CreateTensor(
_fbb,
_shape,
@ -10300,7 +10317,8 @@ inline flatbuffers::Offset<Tensor> CreateTensor(flatbuffers::FlatBufferBuilder &
_name,
_quantization,
_is_variable,
_sparsity);
_sparsity,
_shape_signature);
}
inline Conv2DOptionsT *Conv2DOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {

View File

@ -391,6 +391,12 @@ typedef struct TfLiteTensor {
// This is optional. The field is NULL if a tensor is dense.
// WARNING: This is an experimental interface that is subject to change.
TfLiteSparsity* sparsity;
// Optional. Encodes shapes with unknown dimensions with -1. This field is
// only populated when unknown dimensions exist in a read-write tensor (i.e.
// an input or output tensor). (e.g. `dims` contains [1, 1, 1, 3] and
// `dims_signature` contains [1, -1, -1, 3]).
const TfLiteIntArray* dims_signature;
} TfLiteTensor;
#ifndef TF_LITE_STATIC_MEMORY