Add support for unknown dimensions to TFLite using the MLIR converter.
PiperOrigin-RevId: 292563455
Change-Id: Ib5700cfe6faee177027329e32089abb3bcc9adaf
parent 4e20c32249
commit 55912083e2
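
Usage sketch (not part of the commit): a minimal example, assuming the TF 2.x Python API at this revision, of converting a function with an unknown (None) dimension via the MLIR converter and resizing the input at runtime. It mirrors the testSizeNone and testMatMul tests added below; names such as `model` are illustrative only.

    import numpy as np
    import tensorflow as tf

    # A function whose input has an unknown (None) first dimension.
    @tf.function(input_signature=[tf.TensorSpec(shape=[None, 4], dtype=tf.float32)])
    def model(x):
      fill = tf.transpose(tf.fill(tf.shape(x), 1.0))
      return tf.matmul(fill, x)

    converter = tf.lite.TFLiteConverter.from_concrete_functions(
        [model.get_concrete_function()])
    converter.experimental_new_converter = True  # unknown dims need the MLIR converter
    tflite_model = converter.convert()

    interpreter = tf.lite.Interpreter(model_content=tflite_model)
    input_details = interpreter.get_input_details()
    print(input_details[0]['shape'])            # [1, 4]: unknown dim defaults to 1
    print(input_details[0]['shape_signature'])  # [-1, 4]: -1 marks the unknown dim

    # Resize the unknown dimension to a concrete size, then run.
    interpreter.resize_tensor_input(input_details[0]['index'], [10, 4])
    interpreter.allocate_tensors()
    interpreter.set_tensor(input_details[0]['index'],
                           np.ones((10, 4), dtype=np.float32))
    interpreter.invoke()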
@@ -610,6 +610,7 @@ Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
   };
 
   std::vector<int32_t> shape;
+  std::vector<int32_t> shape_signature;
   if (type.hasStaticShape()) {
     llvm::ArrayRef<int64_t> shape_ref = type.getShape();
     if (mlir::failed(check_shape(shape_ref))) return llvm::None;
@@ -627,7 +628,17 @@ Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
 
     shape = std::vector<int32_t>(shape_ref.begin(), shape_ref.end());
-  }
+  } else if (type.hasRank()) {
+    llvm::ArrayRef<int64_t> shape_ref = type.getShape();
+    if (mlir::failed(check_shape(shape_ref))) return llvm::None;
+
+    shape.reserve(shape_ref.size());
+    for (auto& dim : shape_ref) {
+      shape.push_back(dim == -1 ? 1 : dim);
+    }
+    shape_signature = std::vector<int32_t>(shape_ref.begin(), shape_ref.end());
+  }
 
   Type element_type = type.getElementType();
   tflite::TensorType tflite_element_type =
       GetTFLiteType(type.getElementType()).ValueOrDie();
@@ -664,10 +675,19 @@ Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
       break;
     }
   }
-  return tflite::CreateTensor(
-      builder_, builder_.CreateVector(shape), tflite_element_type,
-      (is_variable ? 0 : buffer_idx), builder_.CreateString(name), q_params,
-      /*is_variable=*/is_variable);
+
+  if (shape_signature.empty()) {
+    return tflite::CreateTensor(
+        builder_, builder_.CreateVector(shape), tflite_element_type,
+        (is_variable ? 0 : buffer_idx), builder_.CreateString(name), q_params,
+        /*is_variable=*/is_variable);
+  } else {
+    return tflite::CreateTensor(
+        builder_, builder_.CreateVector(shape), tflite_element_type,
+        (is_variable ? 0 : buffer_idx), builder_.CreateString(name), q_params,
+        /*is_variable=*/is_variable, /*sparsity=*/0,
+        /*shape_signature=*/builder_.CreateVector(shape_signature));
+  }
 }
 
 BufferOffset<tflite::Operator> Translator::BuildIfOperator(
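Note on the exporter change above (illustrative, not part of the commit): for a ranked tensor type with dynamic dimensions, the exporter writes 1 into `shape` in place of each unknown dimension and records the original dimensions, with -1 for unknown ones, in `shape_signature`; a fully static shape produces no `shape_signature`. A rough Python sketch of that mapping, using -1 or None to stand for an unknown dimension:

    def export_shapes(dims):
      # dims: list of ints, with -1 or None marking an unknown dimension.
      if all(d not in (-1, None) for d in dims):
        return list(dims), None  # static shape: no shape_signature is emitted
      shape = [1 if d in (-1, None) else d for d in dims]
      shape_signature = [-1 if d in (-1, None) else d for d in dims]
      return shape, shape_signature

    print(export_shapes([1, -1, 16, 3]))  # ([1, 1, 16, 3], [1, -1, 16, 3])
    print(export_shapes([1, 8, 16, 3]))   # ([1, 8, 16, 3], None)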
@@ -140,6 +140,11 @@ void TfLiteTensorFree(TfLiteTensor* t) {
   if (t->dims) TfLiteIntArrayFree(t->dims);
   t->dims = NULL;
 
+  if (t->dims_signature) {
+    TfLiteIntArrayFree((TfLiteIntArray *) t->dims_signature);
+  }
+  t->dims_signature = NULL;
+
   TfLiteQuantizationFree(&t->quantization);
   TfLiteSparsityFree(t->sparsity);
   t->sparsity = NULL;
@@ -391,6 +391,12 @@ typedef struct TfLiteTensor {
   // This is optional. The field is NULL if a tensor is dense.
   // WARNING: This is an experimental interface that is subject to change.
   TfLiteSparsity* sparsity;
+
+  // Optional. Encodes shapes with unknown dimensions with -1. This field is
+  // only populated when unknown dimensions exist in a read-write tensor (i.e.
+  // an input or output tensor). (e.g. `dims` contains [1, 1, 1, 3] and
+  // `dims_signature` contains [1, -1, -1, 3]).
+  const TfLiteIntArray* dims_signature;
 } TfLiteTensor;
 
 #ifndef TF_LITE_STATIC_MEMORY
@@ -95,6 +95,7 @@ TEST(Quantization, TestQuantizationFree) {
   // Set these values, otherwise TfLiteTensorFree has uninitialized values.
   t.allocation_type = kTfLiteArenaRw;
   t.dims = nullptr;
+  t.dims_signature = nullptr;
   t.quantization.type = kTfLiteAffineQuantization;
   t.sparsity = nullptr;
   auto* params = reinterpret_cast<TfLiteAffineQuantization*>(
@@ -110,6 +111,7 @@ TEST(Sparsity, TestSparsityFree) {
   // Set these values, otherwise TfLiteTensorFree has uninitialized values.
   t.allocation_type = kTfLiteArenaRw;
   t.dims = nullptr;
+  t.dims_signature = nullptr;
 
   // A dummy CSR sparse matrix.
   t.sparsity = static_cast<TfLiteSparsity*>(malloc(sizeof(TfLiteSparsity)));
@@ -1074,7 +1074,8 @@ TfLiteStatus Subgraph::SetTensorParametersReadOnly(
 // to Interpreter.
 TfLiteStatus Subgraph::SetTensorParametersReadWrite(
     int tensor_index, TfLiteType type, const char* name, const size_t rank,
-    const int* dims, TfLiteQuantization quantization, bool is_variable) {
+    const int* dims, TfLiteQuantization quantization, bool is_variable,
+    const size_t rank_dims_signature, const int* dims_signature) {
   // Ensure quantization cleanup on failure.
   ScopedTfLiteQuantization scoped_quantization(&quantization);
   if (state_ == kStateInvokableAndImmutable) {
@@ -1114,6 +1115,8 @@ TfLiteStatus Subgraph::SetTensorParametersReadWrite(
   // TODO(suharshs): Update TfLiteTensorReset to include the new quantization
   // if there are other required callers.
   tensor.quantization = *scoped_quantization.release();
+  tensor.dims_signature =
+      ConvertArrayToTfLiteIntArray(rank_dims_signature, dims_signature);
   return kTfLiteOk;
 }
 
@@ -114,15 +114,17 @@ class Subgraph {
   inline TfLiteStatus SetTensorParametersReadWrite(
       int tensor_index, TfLiteType type, const char* name,
       const std::vector<int>& dims, TfLiteQuantization quantization,
-      bool is_variable = false) {
+      bool is_variable = false, const size_t rank_dims_signature = 0,
+      const int* dims_signature = nullptr) {
     return SetTensorParametersReadWrite(tensor_index, type, name, dims.size(),
-                                        dims.data(), quantization, is_variable);
+                                        dims.data(), quantization, is_variable,
+                                        rank_dims_signature, dims_signature);
   }
-  TfLiteStatus SetTensorParametersReadWrite(int tensor_index, TfLiteType type,
-                                            const char* name, const size_t rank,
-                                            const int* dims,
-                                            TfLiteQuantization quantization,
-                                            bool is_variable = false);
+  TfLiteStatus SetTensorParametersReadWrite(
+      int tensor_index, TfLiteType type, const char* name, const size_t rank,
+      const int* dims, TfLiteQuantization quantization,
+      bool is_variable = false, const size_t rank_dims_signature = 0,
+      const int* dims_signature = nullptr);
 
   // WARNING: Experimental interface, subject to change
   // Overrides execution plan. This bounds checks indices sent in.
@@ -563,6 +563,13 @@ TfLiteStatus InterpreterBuilder::ParseTensors(
       status = kTfLiteError;
     }
 
+    size_t dims_signature_rank = 0;
+    const int* dims_signature_data = nullptr;
+    if (tensor->shape_signature()) {
+      dims_signature_rank = tensor->shape_signature()->Length();
+      dims_signature_data = tensor->shape_signature()->data();
+    }
+
     bool is_variable = tensor->is_variable();
     if (buffer_ptr) {
       if (is_variable) {
@@ -590,9 +597,9 @@ TfLiteStatus InterpreterBuilder::ParseTensors(
         status = kTfLiteError;
       }
     } else {
-      if (subgraph->SetTensorParametersReadWrite(i, type, get_name(tensor),
-                                                 dims, quantization,
-                                                 is_variable) != kTfLiteOk) {
+      if (subgraph->SetTensorParametersReadWrite(
+              i, type, get_name(tensor), dims, quantization, is_variable,
+              dims_signature_rank, dims_signature_data) != kTfLiteOk) {
         error_reporter_->Report("Tensor %d is invalidly specified in schema.\n",
                                 i);
         status = kTfLiteError;
@@ -35,6 +35,7 @@ from tensorflow.lite.python import wrap_toco
 from tensorflow.lite.toco import model_flags_pb2 as _model_flags_pb2
 from tensorflow.lite.toco import toco_flags_pb2 as _toco_flags_pb2
 from tensorflow.lite.toco import types_pb2 as _types_pb2
+from tensorflow.python.framework import tensor_shape
 from tensorflow.python.platform import resource_loader as _resource_loader
 from tensorflow.python.util import deprecation
 from tensorflow.python.util.tf_export import tf_export as _tf_export
@@ -384,7 +385,16 @@ def build_toco_convert_protos(input_tensors,
       shape = input_tensor.shape
     else:
       shape = input_shapes[idx]
-    input_array.shape.dims.extend(list(map(int, shape)))
+
+    # Create shapes with -1 for unknown dimensions.
+    dims = []
+    for dim in shape:
+      if (dim is None or
+          (isinstance(dim, tensor_shape.Dimension) and dim.value is None)):
+        dims.append(-1)
+      else:
+        dims.append(int(dim))
+    input_array.shape.dims.extend(dims)
 
   for output_tensor in output_tensors:
     model.output_arrays.append(util.get_tensor_name(output_tensor))
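Illustration (not part of the commit) of the None-to--1 mapping added above, using a partially known tf.TensorShape the way build_toco_convert_protos receives it:

    import tensorflow as tf

    shape = tf.TensorShape([1, None, 16, 3])
    dims = [-1 if d is None else int(d) for d in shape.as_list()]
    print(dims)  # [1, -1, 16, 3]: unknown dimensions become -1 in the model flags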
@@ -320,6 +320,7 @@ class Interpreter(object):
     tensor_index = int(tensor_index)
     tensor_name = self._interpreter.TensorName(tensor_index)
     tensor_size = self._interpreter.TensorSize(tensor_index)
+    tensor_size_signature = self._interpreter.TensorSizeSignature(tensor_index)
     tensor_type = self._interpreter.TensorType(tensor_index)
     tensor_quantization = self._interpreter.TensorQuantization(tensor_index)
     tensor_quantization_params = self._interpreter.TensorQuantizationParameters(
@@ -332,6 +333,7 @@ class Interpreter(object):
         'name': tensor_name,
         'index': tensor_index,
         'shape': tensor_size,
        'shape_signature': tensor_size_signature,
         'dtype': tensor_type,
         'quantization': tensor_quantization,
         'quantization_parameters': {
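With the change above, every entry returned by get_input_details() and get_output_details() carries a 'shape_signature' array alongside 'shape'. A short sketch of reading it (the model file name is hypothetical; it assumes a model converted with an unknown dimension):

    import tensorflow as tf

    interpreter = tf.lite.Interpreter(model_path='model_with_unknown_dims.tflite')
    details = interpreter.get_input_details()[0]
    print(details['shape'])            # concrete dims, unknown ones defaulted to 1
    print(details['shape_signature'])  # original dims, -1 where a dim was unknown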
@@ -301,6 +301,23 @@ PyObject* InterpreterWrapper::TensorSize(int i) const {
   return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array));
 }
 
+PyObject* InterpreterWrapper::TensorSizeSignature(int i) const {
+  TFLITE_PY_ENSURE_VALID_INTERPRETER();
+  TFLITE_PY_TENSOR_BOUNDS_CHECK(i);
+
+  const TfLiteTensor* tensor = interpreter_->tensor(i);
+  const int32_t* size_signature_data = nullptr;
+  int32_t size_signature_size = 0;
+  if (tensor->dims_signature != nullptr) {
+    size_signature_data = tensor->dims_signature->data;
+    size_signature_size = tensor->dims_signature->size;
+  }
+  PyObject* np_array =
+      PyArrayFromIntVector(size_signature_data, size_signature_size);
+
+  return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array));
+}
+
 PyObject* InterpreterWrapper::TensorQuantization(int i) const {
   TFLITE_PY_ENSURE_VALID_INTERPRETER();
   TFLITE_PY_TENSOR_BOUNDS_CHECK(i);
@@ -69,6 +69,7 @@ class InterpreterWrapper {
   std::string TensorName(int i) const;
   PyObject* TensorType(int i) const;
   PyObject* TensorSize(int i) const;
+  PyObject* TensorSizeSignature(int i) const;
   // Deprecated in favor of TensorQuantizationScales, below.
   PyObject* TensorQuantization(int i) const;
   PyObject* TensorQuantizationParameters(int i) const;
@@ -261,6 +261,16 @@ class TFLiteConverterBase(object):
         self.representative_dataset.input_gen, inference_input_type,
         inference_output_type, allow_float, enable_mlir_quantizer)
 
+  def _is_unknown_shapes_allowed(self):
+    # TODO(b/128319310): Investigate which quantization methods work.
+    if self._any_optimization_enabled():
+      return False
+
+    # Unknown dimensions are only allowed with the new converter.
+    if not self.experimental_new_converter:
+      return False
+    return True
+
   def _get_base_converter_args(self):
     """Returns the base converter args.
 
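The helper above gates unknown-shape support: it passes only when no optimizations are enabled and the MLIR converter is selected (the V1 subclass further below additionally requires that conversion_summary_dir is unset). A sketch of the resulting behavior, assuming concrete_func is a function whose signature has None in a non-batch dimension:

    import tensorflow as tf

    converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
    converter.experimental_new_converter = True
    converter.optimizations = [tf.lite.Optimize.DEFAULT]  # disables unknown-shape support
    try:
      converter.convert()
    except ValueError as e:
      print(e)  # "None is only supported in the 1st dimension. ..."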
@@ -456,19 +466,21 @@ class TFLiteConverterV2(TFLiteConverterBase):
         config=self._grappler_config(),
         graph=frozen_func.graph)
 
-    # Checks dimensions in input tensor.
-    for tensor in input_tensors:
-      # Note that shape_list might be empty for scalar shapes.
-      shape_list = tensor.shape.as_list()
-      if None in shape_list[1:]:
-        raise ValueError(
-            "None is only supported in the 1st dimension. Tensor '{0}' has "
-            "invalid shape '{1}'.".format(_get_tensor_name(tensor), shape_list))
-      elif shape_list and shape_list[0] is None:
-        # Set the batch size to 1 if undefined.
-        shape = tensor.shape.as_list()
-        shape[0] = 1
-        tensor.set_shape(shape)
+    if not self._is_unknown_shapes_allowed():
+      # Checks dimensions in input tensor.
+      for tensor in input_tensors:
+        # Note that shape_list might be empty for scalar shapes.
+        shape_list = tensor.shape.as_list()
+        if None in shape_list[1:]:
+          raise ValueError(
+              "None is only supported in the 1st dimension. Tensor '{0}' has "
+              "invalid shape '{1}'.".format(
+                  _get_tensor_name(tensor), shape_list))
+        elif shape_list and shape_list[0] is None:
+          # Set the batch size to 1 if undefined.
+          shape = tensor.shape.as_list()
+          shape[0] = 1
+          tensor.set_shape(shape)
 
     self._validate_quantization()
     self._validate_representative_dataset()
@@ -942,7 +954,7 @@ class TFLiteConverter(TFLiteConverterBase):
         None value for dimension in input_tensor.
     """
     # Checks dimensions in input tensor.
-    if self._has_valid_tensors():
+    if not self._is_unknown_shapes_allowed() and self._has_valid_tensors():
       for tensor in self._input_tensors:
         shape = tensor.shape
         if not shape:
@@ -1115,6 +1127,20 @@ class TFLiteConverter(TFLiteConverterBase):
         shape[0] = batch_size
         tensor.set_shape(shape)
 
+  def _is_unknown_shapes_allowed(self):
+    if not super(TFLiteConverter, self)._is_unknown_shapes_allowed():
+      return False
+
+    # `conversion_summary_dir` calls TOCO. Unknown shapes are only supported by
+    # the MLIR converter.
+    if self.conversion_summary_dir:
+      logging.warning(
+          "`conversion_summary_dir` does not work with unknown shapes. "
+          "Graphs with unknown shapes might be different than when this flag "
+          "is disabled.")
+      return False
+    return True
+
 
 @_tf_export(v1=["lite.TocoConverter"])
 class TocoConverter(object):
@@ -318,9 +318,11 @@ class FromSessionTest(TestModels, parameterized.TestCase):
       out_tensor = in_tensor + in_tensor
       sess = session.Session()
 
-    # Test None as shape.
+    # Test None as shape when dynamic shapes are disabled. Run with TOCO in
+    # order to invoke shape checking code.
     converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                   [out_tensor])
+    converter.experimental_new_converter = False
     with self.assertRaises(ValueError) as error:
       converter.convert()
     self.assertEqual('Provide an input shape for input array \'Placeholder\'.',
@@ -375,9 +377,11 @@ class FromSessionTest(TestModels, parameterized.TestCase):
       out_tensor = in_tensor + in_tensor
       sess = session.Session()
 
-    # Test invalid shape. None after 1st dimension.
+    # Test invalid shape. None after 1st dimension. Run with TOCO in order to
+    # invoke shape checking code.
     converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                   [out_tensor])
+    converter.experimental_new_converter = False
    with self.assertRaises(ValueError) as error:
       converter.convert()
     self.assertEqual(
@@ -385,6 +389,44 @@ class FromSessionTest(TestModels, parameterized.TestCase):
         '\'Placeholder\' has invalid shape \'[1, None, 16, 3]\'.',
         str(error.exception))
 
+  def testSizeNone(self):
+    with ops.Graph().as_default():
+      in_tensor = array_ops.placeholder(
+          shape=[1, None, 16, 3], dtype=dtypes.float32)
+      out_tensor = in_tensor + in_tensor
+      sess = session.Session()
+
+    # Test None after 1st dimension.
+    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
+                                                  [out_tensor])
+    converter.experimental_new_converter = True
+    tflite_model = converter.convert()
+
+    # Check values from converted model.
+    interpreter = Interpreter(model_content=tflite_model)
+    input_details = interpreter.get_input_details()
+    self.assertLen(input_details, 1)
+    self.assertEqual('Placeholder', input_details[0]['name'])
+    self.assertEqual(np.float32, input_details[0]['dtype'])
+    self.assertTrue(([1, 1, 16, 3] == input_details[0]['shape']).all())
+    self.assertTrue(([1, -1, 16,
+                      3] == input_details[0]['shape_signature']).all())
+    self.assertEqual((0., 0.), input_details[0]['quantization'])
+
+    # Resize tensor and invoke.
+    interpreter.resize_tensor_input(0, [1, 16, 16, 3])
+    interpreter.allocate_tensors()
+    interpreter.invoke()
+
+    input_details = interpreter.get_input_details()
+    self.assertLen(input_details, 1)
+    self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
+    self.assertTrue(([1, -1, 16,
+                      3] == input_details[0]['shape_signature']).all())
+
+    output_details = interpreter.get_output_details()
+    self.assertFalse(output_details[0]['shape_signature'])
+
   def testBatchSizeValid(self):
     with ops.Graph().as_default():
       in_tensor = array_ops.placeholder(
@@ -54,12 +54,28 @@ from tensorflow.python.training.tracking import tracking
 
 class TestModels(test_util.TensorFlowTestCase, parameterized.TestCase):
 
-  def _evaluateTFLiteModel(self, tflite_model, input_data):
-    """Evaluates the model on the `input_data`."""
+  def _evaluateTFLiteModel(self, tflite_model, input_data, input_shapes=None):
+    """Evaluates the model on the `input_data`.
+
+    Args:
+      tflite_model: TensorFlow Lite model.
+      input_data: List of EagerTensor const ops containing the input data for
+        each input tensor.
+      input_shapes: List of tuples representing the `shape_signature` and the
+        new shape of each input tensor that has unknown dimensions.
+
+    Returns:
+      [np.ndarray]
+    """
     interpreter = Interpreter(model_content=tflite_model)
+    input_details = interpreter.get_input_details()
+    if input_shapes:
+      for idx, (shape_signature, final_shape) in enumerate(input_shapes):
+        self.assertTrue(
+            (input_details[idx]['shape_signature'] == shape_signature).all())
+        interpreter.resize_tensor_input(idx, final_shape)
     interpreter.allocate_tensors()
 
     input_details = interpreter.get_input_details()
     output_details = interpreter.get_output_details()
 
     for input_tensor, tensor_data in zip(input_details, input_data):
@@ -795,5 +811,62 @@ class GrapplerTest(TestModels):
     actual_value = self._evaluateTFLiteModel(hybrid_tflite_model, [input_data])
     np.testing.assert_almost_equal(expected_value.numpy(), actual_value[0])
 
 
+class UnknownShapes(TestModels):
+
+  @test_util.run_v2_only
+  def testMatMul(self):
+    input_data = constant_op.constant(
+        np.array(np.random.random_sample((10, 4)), dtype=np.float32))
+
+    @def_function.function(input_signature=[
+        tensor_spec.TensorSpec(shape=[None, 4], dtype=dtypes.float32)
+    ])
+    def model(in_tensor):
+      shape = array_ops.shape_v2(in_tensor)
+      fill = array_ops.transpose_v2(array_ops.fill(shape, 1.))
+      return math_ops.matmul(fill, in_tensor)
+
+    concrete_func = model.get_concrete_function()
+
+    converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func])
+    converter.experimental_new_converter = True
+    tflite_model = converter.convert()
+
+    # Check values from converted model.
+    expected_value = concrete_func(input_data)
+    actual_value = self._evaluateTFLiteModel(
+        tflite_model, [input_data], input_shapes=[([-1, 4], [10, 4])])
+    np.testing.assert_almost_equal(
+        expected_value.numpy(), actual_value[0], decimal=6)
+
+  def testBatchMatMul(self):
+    self.skipTest('BatchMatMulV2 ranked tensor check fails.')
+    input_data_1 = constant_op.constant(
+        np.array(np.random.random_sample((1, 256, 256)), dtype=np.float32))
+    input_data_2 = constant_op.constant(
+        np.array(np.random.random_sample((1, 2, 256)), dtype=np.float32))
+
+    @def_function.function(input_signature=[
+        tensor_spec.TensorSpec(shape=[1, 256, 256], dtype=dtypes.float32),
+        tensor_spec.TensorSpec(shape=[1, None, 256], dtype=dtypes.float32)
+    ])
+    def model(in_tensor_1, in_tensor_2):
+      return math_ops.matmul(in_tensor_1, in_tensor_2)
+
+    concrete_func = model.get_concrete_function()
+
+    converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func])
+    converter.experimental_new_converter = True
+    tflite_model = converter.convert()
+
+    # Check values from converted model.
+    expected_value = concrete_func(input_data_1, input_data_2)
+    actual_value = self._evaluateTFLiteModel(
+        tflite_model, [input_data_1, input_data_2],
+        input_shapes={1: [1, 2, 256]})
+    np.testing.assert_almost_equal(expected_value.numpy(), actual_value[0])
+
 
 if __name__ == '__main__':
   test.main()
@@ -178,6 +178,10 @@ table Tensor {
   // Parameters to encode a sparse tensor. See the example in
   // tensorflow/lite/testdata/sparse_tensor.json.
   sparsity:SparsityParameters;  // Optional.
+
+  // Encodes `shape` with unknown dimensions. Unknown dimensions are
+  // represented with -1.
+  shape_signature:[int]; // Optional.
 }
 
 // A list of builtin operators. Builtin operators are slightly faster than custom
@@ -3175,6 +3175,7 @@ struct TensorT : public flatbuffers::NativeTable {
   std::unique_ptr<QuantizationParametersT> quantization;
   bool is_variable;
   std::unique_ptr<SparsityParametersT> sparsity;
+  std::vector<int32_t> shape_signature;
   TensorT()
       : type(TensorType_FLOAT32),
         buffer(0),
@@ -3191,7 +3192,8 @@ struct Tensor FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
     VT_NAME = 10,
     VT_QUANTIZATION = 12,
     VT_IS_VARIABLE = 14,
-    VT_SPARSITY = 16
+    VT_SPARSITY = 16,
+    VT_SHAPE_SIGNATURE = 18
   };
   const flatbuffers::Vector<int32_t> *shape() const {
     return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_SHAPE);
@@ -3214,6 +3216,9 @@ struct Tensor FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
   const SparsityParameters *sparsity() const {
     return GetPointer<const SparsityParameters *>(VT_SPARSITY);
   }
+  const flatbuffers::Vector<int32_t> *shape_signature() const {
+    return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_SHAPE_SIGNATURE);
+  }
   bool Verify(flatbuffers::Verifier &verifier) const {
     return VerifyTableStart(verifier) &&
            VerifyOffset(verifier, VT_SHAPE) &&
@@ -3227,6 +3232,8 @@ struct Tensor FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
           VerifyField<uint8_t>(verifier, VT_IS_VARIABLE) &&
           VerifyOffset(verifier, VT_SPARSITY) &&
           verifier.VerifyTable(sparsity()) &&
+          VerifyOffset(verifier, VT_SHAPE_SIGNATURE) &&
+          verifier.VerifyVector(shape_signature()) &&
           verifier.EndTable();
   }
   TensorT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@@ -3258,6 +3265,9 @@ struct TensorBuilder {
   void add_sparsity(flatbuffers::Offset<SparsityParameters> sparsity) {
     fbb_.AddOffset(Tensor::VT_SPARSITY, sparsity);
   }
+  void add_shape_signature(flatbuffers::Offset<flatbuffers::Vector<int32_t>> shape_signature) {
+    fbb_.AddOffset(Tensor::VT_SHAPE_SIGNATURE, shape_signature);
+  }
   explicit TensorBuilder(flatbuffers::FlatBufferBuilder &_fbb)
       : fbb_(_fbb) {
     start_ = fbb_.StartTable();
@@ -3278,8 +3288,10 @@ inline flatbuffers::Offset<Tensor> CreateTensor(
     flatbuffers::Offset<flatbuffers::String> name = 0,
     flatbuffers::Offset<QuantizationParameters> quantization = 0,
     bool is_variable = false,
-    flatbuffers::Offset<SparsityParameters> sparsity = 0) {
+    flatbuffers::Offset<SparsityParameters> sparsity = 0,
+    flatbuffers::Offset<flatbuffers::Vector<int32_t>> shape_signature = 0) {
   TensorBuilder builder_(_fbb);
+  builder_.add_shape_signature(shape_signature);
   builder_.add_sparsity(sparsity);
   builder_.add_quantization(quantization);
   builder_.add_name(name);
@@ -3298,9 +3310,11 @@ inline flatbuffers::Offset<Tensor> CreateTensorDirect(
     const char *name = nullptr,
     flatbuffers::Offset<QuantizationParameters> quantization = 0,
     bool is_variable = false,
-    flatbuffers::Offset<SparsityParameters> sparsity = 0) {
+    flatbuffers::Offset<SparsityParameters> sparsity = 0,
+    const std::vector<int32_t> *shape_signature = nullptr) {
   auto shape__ = shape ? _fbb.CreateVector<int32_t>(*shape) : 0;
   auto name__ = name ? _fbb.CreateString(name) : 0;
+  auto shape_signature__ = shape_signature ? _fbb.CreateVector<int32_t>(*shape_signature) : 0;
   return tflite::CreateTensor(
       _fbb,
       shape__,
@@ -3309,7 +3323,8 @@ inline flatbuffers::Offset<Tensor> CreateTensorDirect(
       name__,
       quantization,
       is_variable,
-      sparsity);
+      sparsity,
+      shape_signature__);
 }
 
 flatbuffers::Offset<Tensor> CreateTensor(flatbuffers::FlatBufferBuilder &_fbb, const TensorT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
@@ -10275,6 +10290,7 @@ inline void Tensor::UnPackTo(TensorT *_o, const flatbuffers::resolver_function_t
   { auto _e = quantization(); if (_e) _o->quantization = std::unique_ptr<QuantizationParametersT>(_e->UnPack(_resolver)); };
   { auto _e = is_variable(); _o->is_variable = _e; };
   { auto _e = sparsity(); if (_e) _o->sparsity = std::unique_ptr<SparsityParametersT>(_e->UnPack(_resolver)); };
+  { auto _e = shape_signature(); if (_e) { _o->shape_signature.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->shape_signature[_i] = _e->Get(_i); } } };
 }
 
 inline flatbuffers::Offset<Tensor> Tensor::Pack(flatbuffers::FlatBufferBuilder &_fbb, const TensorT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
@@ -10292,6 +10308,7 @@ inline flatbuffers::Offset<Tensor> CreateTensor(flatbuffers::FlatBufferBuilder &
   auto _quantization = _o->quantization ? CreateQuantizationParameters(_fbb, _o->quantization.get(), _rehasher) : 0;
   auto _is_variable = _o->is_variable;
   auto _sparsity = _o->sparsity ? CreateSparsityParameters(_fbb, _o->sparsity.get(), _rehasher) : 0;
+  auto _shape_signature = _o->shape_signature.size() ? _fbb.CreateVector(_o->shape_signature) : 0;
   return tflite::CreateTensor(
       _fbb,
       _shape,
@@ -10300,7 +10317,8 @@ inline flatbuffers::Offset<Tensor> CreateTensor(flatbuffers::FlatBufferBuilder &
       _name,
       _quantization,
       _is_variable,
-      _sparsity);
+      _sparsity,
+      _shape_signature);
 }
 
 inline Conv2DOptionsT *Conv2DOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {