Add support for unknown dimensions to TFLite using MLIR converter.
PiperOrigin-RevId: 292563455 Change-Id: Ib5700cfe6faee177027329e32089abb3bcc9adaf
This commit is contained in:
parent
4e20c32249
commit
55912083e2
tensorflow
compiler/mlir/lite
lite
c
core
model.ccpython
schema
tools/benchmark/experimental/c
@ -610,6 +610,7 @@ Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
|
|||||||
};
|
};
|
||||||
|
|
||||||
std::vector<int32_t> shape;
|
std::vector<int32_t> shape;
|
||||||
|
std::vector<int32_t> shape_signature;
|
||||||
if (type.hasStaticShape()) {
|
if (type.hasStaticShape()) {
|
||||||
llvm::ArrayRef<int64_t> shape_ref = type.getShape();
|
llvm::ArrayRef<int64_t> shape_ref = type.getShape();
|
||||||
if (mlir::failed(check_shape(shape_ref))) return llvm::None;
|
if (mlir::failed(check_shape(shape_ref))) return llvm::None;
|
||||||
@ -627,7 +628,17 @@ Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
|
|||||||
|
|
||||||
shape = std::vector<int32_t>(shape_ref.begin(), shape_ref.end());
|
shape = std::vector<int32_t>(shape_ref.begin(), shape_ref.end());
|
||||||
}
|
}
|
||||||
|
} else if (type.hasRank()) {
|
||||||
|
llvm::ArrayRef<int64_t> shape_ref = type.getShape();
|
||||||
|
if (mlir::failed(check_shape(shape_ref))) return llvm::None;
|
||||||
|
|
||||||
|
shape.reserve(shape_ref.size());
|
||||||
|
for (auto& dim : shape_ref) {
|
||||||
|
shape.push_back(dim == -1 ? 1 : dim);
|
||||||
}
|
}
|
||||||
|
shape_signature = std::vector<int32_t>(shape_ref.begin(), shape_ref.end());
|
||||||
|
}
|
||||||
|
|
||||||
Type element_type = type.getElementType();
|
Type element_type = type.getElementType();
|
||||||
tflite::TensorType tflite_element_type =
|
tflite::TensorType tflite_element_type =
|
||||||
GetTFLiteType(type.getElementType()).ValueOrDie();
|
GetTFLiteType(type.getElementType()).ValueOrDie();
|
||||||
@ -664,10 +675,19 @@ Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (shape_signature.empty()) {
|
||||||
return tflite::CreateTensor(
|
return tflite::CreateTensor(
|
||||||
builder_, builder_.CreateVector(shape), tflite_element_type,
|
builder_, builder_.CreateVector(shape), tflite_element_type,
|
||||||
(is_variable ? 0 : buffer_idx), builder_.CreateString(name), q_params,
|
(is_variable ? 0 : buffer_idx), builder_.CreateString(name), q_params,
|
||||||
/*is_variable=*/is_variable);
|
/*is_variable=*/is_variable);
|
||||||
|
} else {
|
||||||
|
return tflite::CreateTensor(
|
||||||
|
builder_, builder_.CreateVector(shape), tflite_element_type,
|
||||||
|
(is_variable ? 0 : buffer_idx), builder_.CreateString(name), q_params,
|
||||||
|
/*is_variable=*/is_variable, /*sparsity=*/0,
|
||||||
|
/*shape_signature=*/builder_.CreateVector(shape_signature));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
BufferOffset<tflite::Operator> Translator::BuildIfOperator(
|
BufferOffset<tflite::Operator> Translator::BuildIfOperator(
|
||||||
|
@ -140,6 +140,11 @@ void TfLiteTensorFree(TfLiteTensor* t) {
|
|||||||
if (t->dims) TfLiteIntArrayFree(t->dims);
|
if (t->dims) TfLiteIntArrayFree(t->dims);
|
||||||
t->dims = NULL;
|
t->dims = NULL;
|
||||||
|
|
||||||
|
if (t->dims_signature) {
|
||||||
|
TfLiteIntArrayFree((TfLiteIntArray *) t->dims_signature);
|
||||||
|
}
|
||||||
|
t->dims_signature = NULL;
|
||||||
|
|
||||||
TfLiteQuantizationFree(&t->quantization);
|
TfLiteQuantizationFree(&t->quantization);
|
||||||
TfLiteSparsityFree(t->sparsity);
|
TfLiteSparsityFree(t->sparsity);
|
||||||
t->sparsity = NULL;
|
t->sparsity = NULL;
|
||||||
|
@ -391,6 +391,12 @@ typedef struct TfLiteTensor {
|
|||||||
// This is optional. The field is NULL if a tensor is dense.
|
// This is optional. The field is NULL if a tensor is dense.
|
||||||
// WARNING: This is an experimental interface that is subject to change.
|
// WARNING: This is an experimental interface that is subject to change.
|
||||||
TfLiteSparsity* sparsity;
|
TfLiteSparsity* sparsity;
|
||||||
|
|
||||||
|
// Optional. Encodes shapes with unknown dimensions with -1. This field is
|
||||||
|
// only populated when unknown dimensions exist in a read-write tensor (i.e.
|
||||||
|
// an input or output tensor). (e.g. `dims` contains [1, 1, 1, 3] and
|
||||||
|
// `dims_signature` contains [1, -1, -1, 3]).
|
||||||
|
const TfLiteIntArray* dims_signature;
|
||||||
} TfLiteTensor;
|
} TfLiteTensor;
|
||||||
|
|
||||||
#ifndef TF_LITE_STATIC_MEMORY
|
#ifndef TF_LITE_STATIC_MEMORY
|
||||||
|
@ -95,6 +95,7 @@ TEST(Quantization, TestQuantizationFree) {
|
|||||||
// Set these values, otherwise TfLiteTensorFree has uninitialized values.
|
// Set these values, otherwise TfLiteTensorFree has uninitialized values.
|
||||||
t.allocation_type = kTfLiteArenaRw;
|
t.allocation_type = kTfLiteArenaRw;
|
||||||
t.dims = nullptr;
|
t.dims = nullptr;
|
||||||
|
t.dims_signature = nullptr;
|
||||||
t.quantization.type = kTfLiteAffineQuantization;
|
t.quantization.type = kTfLiteAffineQuantization;
|
||||||
t.sparsity = nullptr;
|
t.sparsity = nullptr;
|
||||||
auto* params = reinterpret_cast<TfLiteAffineQuantization*>(
|
auto* params = reinterpret_cast<TfLiteAffineQuantization*>(
|
||||||
@ -110,6 +111,7 @@ TEST(Sparsity, TestSparsityFree) {
|
|||||||
// Set these values, otherwise TfLiteTensorFree has uninitialized values.
|
// Set these values, otherwise TfLiteTensorFree has uninitialized values.
|
||||||
t.allocation_type = kTfLiteArenaRw;
|
t.allocation_type = kTfLiteArenaRw;
|
||||||
t.dims = nullptr;
|
t.dims = nullptr;
|
||||||
|
t.dims_signature = nullptr;
|
||||||
|
|
||||||
// A dummy CSR sparse matrix.
|
// A dummy CSR sparse matrix.
|
||||||
t.sparsity = static_cast<TfLiteSparsity*>(malloc(sizeof(TfLiteSparsity)));
|
t.sparsity = static_cast<TfLiteSparsity*>(malloc(sizeof(TfLiteSparsity)));
|
||||||
|
@ -1074,7 +1074,8 @@ TfLiteStatus Subgraph::SetTensorParametersReadOnly(
|
|||||||
// to Interpreter.
|
// to Interpreter.
|
||||||
TfLiteStatus Subgraph::SetTensorParametersReadWrite(
|
TfLiteStatus Subgraph::SetTensorParametersReadWrite(
|
||||||
int tensor_index, TfLiteType type, const char* name, const size_t rank,
|
int tensor_index, TfLiteType type, const char* name, const size_t rank,
|
||||||
const int* dims, TfLiteQuantization quantization, bool is_variable) {
|
const int* dims, TfLiteQuantization quantization, bool is_variable,
|
||||||
|
const size_t rank_dims_signature, const int* dims_signature) {
|
||||||
// Ensure quantization cleanup on failure.
|
// Ensure quantization cleanup on failure.
|
||||||
ScopedTfLiteQuantization scoped_quantization(&quantization);
|
ScopedTfLiteQuantization scoped_quantization(&quantization);
|
||||||
if (state_ == kStateInvokableAndImmutable) {
|
if (state_ == kStateInvokableAndImmutable) {
|
||||||
@ -1114,6 +1115,8 @@ TfLiteStatus Subgraph::SetTensorParametersReadWrite(
|
|||||||
// TODO(suharshs): Update TfLiteTensorReset to include the new quantization
|
// TODO(suharshs): Update TfLiteTensorReset to include the new quantization
|
||||||
// if there are other required callers.
|
// if there are other required callers.
|
||||||
tensor.quantization = *scoped_quantization.release();
|
tensor.quantization = *scoped_quantization.release();
|
||||||
|
tensor.dims_signature =
|
||||||
|
ConvertArrayToTfLiteIntArray(rank_dims_signature, dims_signature);
|
||||||
return kTfLiteOk;
|
return kTfLiteOk;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -114,15 +114,17 @@ class Subgraph {
|
|||||||
inline TfLiteStatus SetTensorParametersReadWrite(
|
inline TfLiteStatus SetTensorParametersReadWrite(
|
||||||
int tensor_index, TfLiteType type, const char* name,
|
int tensor_index, TfLiteType type, const char* name,
|
||||||
const std::vector<int>& dims, TfLiteQuantization quantization,
|
const std::vector<int>& dims, TfLiteQuantization quantization,
|
||||||
bool is_variable = false) {
|
bool is_variable = false, const size_t rank_dims_signature = 0,
|
||||||
|
const int* dims_signature = nullptr) {
|
||||||
return SetTensorParametersReadWrite(tensor_index, type, name, dims.size(),
|
return SetTensorParametersReadWrite(tensor_index, type, name, dims.size(),
|
||||||
dims.data(), quantization, is_variable);
|
dims.data(), quantization, is_variable,
|
||||||
|
rank_dims_signature, dims_signature);
|
||||||
}
|
}
|
||||||
TfLiteStatus SetTensorParametersReadWrite(int tensor_index, TfLiteType type,
|
TfLiteStatus SetTensorParametersReadWrite(
|
||||||
const char* name, const size_t rank,
|
int tensor_index, TfLiteType type, const char* name, const size_t rank,
|
||||||
const int* dims,
|
const int* dims, TfLiteQuantization quantization,
|
||||||
TfLiteQuantization quantization,
|
bool is_variable = false, const size_t rank_dims_signature = 0,
|
||||||
bool is_variable = false);
|
const int* dims_signature = nullptr);
|
||||||
|
|
||||||
// WARNING: Experimental interface, subject to change
|
// WARNING: Experimental interface, subject to change
|
||||||
// Overrides execution plan. This bounds checks indices sent in.
|
// Overrides execution plan. This bounds checks indices sent in.
|
||||||
|
@ -563,6 +563,13 @@ TfLiteStatus InterpreterBuilder::ParseTensors(
|
|||||||
status = kTfLiteError;
|
status = kTfLiteError;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
size_t dims_signature_rank = 0;
|
||||||
|
const int* dims_signature_data = nullptr;
|
||||||
|
if (tensor->shape_signature()) {
|
||||||
|
dims_signature_rank = tensor->shape_signature()->Length();
|
||||||
|
dims_signature_data = tensor->shape_signature()->data();
|
||||||
|
}
|
||||||
|
|
||||||
bool is_variable = tensor->is_variable();
|
bool is_variable = tensor->is_variable();
|
||||||
if (buffer_ptr) {
|
if (buffer_ptr) {
|
||||||
if (is_variable) {
|
if (is_variable) {
|
||||||
@ -590,9 +597,9 @@ TfLiteStatus InterpreterBuilder::ParseTensors(
|
|||||||
status = kTfLiteError;
|
status = kTfLiteError;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if (subgraph->SetTensorParametersReadWrite(i, type, get_name(tensor),
|
if (subgraph->SetTensorParametersReadWrite(
|
||||||
dims, quantization,
|
i, type, get_name(tensor), dims, quantization, is_variable,
|
||||||
is_variable) != kTfLiteOk) {
|
dims_signature_rank, dims_signature_data) != kTfLiteOk) {
|
||||||
error_reporter_->Report("Tensor %d is invalidly specified in schema.\n",
|
error_reporter_->Report("Tensor %d is invalidly specified in schema.\n",
|
||||||
i);
|
i);
|
||||||
status = kTfLiteError;
|
status = kTfLiteError;
|
||||||
|
@ -35,6 +35,7 @@ from tensorflow.lite.python import wrap_toco
|
|||||||
from tensorflow.lite.toco import model_flags_pb2 as _model_flags_pb2
|
from tensorflow.lite.toco import model_flags_pb2 as _model_flags_pb2
|
||||||
from tensorflow.lite.toco import toco_flags_pb2 as _toco_flags_pb2
|
from tensorflow.lite.toco import toco_flags_pb2 as _toco_flags_pb2
|
||||||
from tensorflow.lite.toco import types_pb2 as _types_pb2
|
from tensorflow.lite.toco import types_pb2 as _types_pb2
|
||||||
|
from tensorflow.python.framework import tensor_shape
|
||||||
from tensorflow.python.platform import resource_loader as _resource_loader
|
from tensorflow.python.platform import resource_loader as _resource_loader
|
||||||
from tensorflow.python.util import deprecation
|
from tensorflow.python.util import deprecation
|
||||||
from tensorflow.python.util.tf_export import tf_export as _tf_export
|
from tensorflow.python.util.tf_export import tf_export as _tf_export
|
||||||
@ -384,7 +385,16 @@ def build_toco_convert_protos(input_tensors,
|
|||||||
shape = input_tensor.shape
|
shape = input_tensor.shape
|
||||||
else:
|
else:
|
||||||
shape = input_shapes[idx]
|
shape = input_shapes[idx]
|
||||||
input_array.shape.dims.extend(list(map(int, shape)))
|
|
||||||
|
# Create shapes with -1 for unknown dimensions.
|
||||||
|
dims = []
|
||||||
|
for dim in shape:
|
||||||
|
if (dim is None or
|
||||||
|
(isinstance(dim, tensor_shape.Dimension) and dim.value is None)):
|
||||||
|
dims.append(-1)
|
||||||
|
else:
|
||||||
|
dims.append(int(dim))
|
||||||
|
input_array.shape.dims.extend(dims)
|
||||||
|
|
||||||
for output_tensor in output_tensors:
|
for output_tensor in output_tensors:
|
||||||
model.output_arrays.append(util.get_tensor_name(output_tensor))
|
model.output_arrays.append(util.get_tensor_name(output_tensor))
|
||||||
|
@ -320,6 +320,7 @@ class Interpreter(object):
|
|||||||
tensor_index = int(tensor_index)
|
tensor_index = int(tensor_index)
|
||||||
tensor_name = self._interpreter.TensorName(tensor_index)
|
tensor_name = self._interpreter.TensorName(tensor_index)
|
||||||
tensor_size = self._interpreter.TensorSize(tensor_index)
|
tensor_size = self._interpreter.TensorSize(tensor_index)
|
||||||
|
tensor_size_signature = self._interpreter.TensorSizeSignature(tensor_index)
|
||||||
tensor_type = self._interpreter.TensorType(tensor_index)
|
tensor_type = self._interpreter.TensorType(tensor_index)
|
||||||
tensor_quantization = self._interpreter.TensorQuantization(tensor_index)
|
tensor_quantization = self._interpreter.TensorQuantization(tensor_index)
|
||||||
tensor_quantization_params = self._interpreter.TensorQuantizationParameters(
|
tensor_quantization_params = self._interpreter.TensorQuantizationParameters(
|
||||||
@ -332,6 +333,7 @@ class Interpreter(object):
|
|||||||
'name': tensor_name,
|
'name': tensor_name,
|
||||||
'index': tensor_index,
|
'index': tensor_index,
|
||||||
'shape': tensor_size,
|
'shape': tensor_size,
|
||||||
|
'shape_signature': tensor_size_signature,
|
||||||
'dtype': tensor_type,
|
'dtype': tensor_type,
|
||||||
'quantization': tensor_quantization,
|
'quantization': tensor_quantization,
|
||||||
'quantization_parameters': {
|
'quantization_parameters': {
|
||||||
|
@ -301,6 +301,23 @@ PyObject* InterpreterWrapper::TensorSize(int i) const {
|
|||||||
return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array));
|
return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
PyObject* InterpreterWrapper::TensorSizeSignature(int i) const {
|
||||||
|
TFLITE_PY_ENSURE_VALID_INTERPRETER();
|
||||||
|
TFLITE_PY_TENSOR_BOUNDS_CHECK(i);
|
||||||
|
|
||||||
|
const TfLiteTensor* tensor = interpreter_->tensor(i);
|
||||||
|
const int32_t* size_signature_data = nullptr;
|
||||||
|
int32_t size_signature_size = 0;
|
||||||
|
if (tensor->dims_signature != nullptr) {
|
||||||
|
size_signature_data = tensor->dims_signature->data;
|
||||||
|
size_signature_size = tensor->dims_signature->size;
|
||||||
|
}
|
||||||
|
PyObject* np_array =
|
||||||
|
PyArrayFromIntVector(size_signature_data, size_signature_size);
|
||||||
|
|
||||||
|
return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array));
|
||||||
|
}
|
||||||
|
|
||||||
PyObject* InterpreterWrapper::TensorQuantization(int i) const {
|
PyObject* InterpreterWrapper::TensorQuantization(int i) const {
|
||||||
TFLITE_PY_ENSURE_VALID_INTERPRETER();
|
TFLITE_PY_ENSURE_VALID_INTERPRETER();
|
||||||
TFLITE_PY_TENSOR_BOUNDS_CHECK(i);
|
TFLITE_PY_TENSOR_BOUNDS_CHECK(i);
|
||||||
|
@ -69,6 +69,7 @@ class InterpreterWrapper {
|
|||||||
std::string TensorName(int i) const;
|
std::string TensorName(int i) const;
|
||||||
PyObject* TensorType(int i) const;
|
PyObject* TensorType(int i) const;
|
||||||
PyObject* TensorSize(int i) const;
|
PyObject* TensorSize(int i) const;
|
||||||
|
PyObject* TensorSizeSignature(int i) const;
|
||||||
// Deprecated in favor of TensorQuantizationScales, below.
|
// Deprecated in favor of TensorQuantizationScales, below.
|
||||||
PyObject* TensorQuantization(int i) const;
|
PyObject* TensorQuantization(int i) const;
|
||||||
PyObject* TensorQuantizationParameters(int i) const;
|
PyObject* TensorQuantizationParameters(int i) const;
|
||||||
|
@ -261,6 +261,16 @@ class TFLiteConverterBase(object):
|
|||||||
self.representative_dataset.input_gen, inference_input_type,
|
self.representative_dataset.input_gen, inference_input_type,
|
||||||
inference_output_type, allow_float, enable_mlir_quantizer)
|
inference_output_type, allow_float, enable_mlir_quantizer)
|
||||||
|
|
||||||
|
def _is_unknown_shapes_allowed(self):
|
||||||
|
# TODO(b/128319310): Investigate which quantization methods work.
|
||||||
|
if self._any_optimization_enabled():
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Unknown dimensions are only allowed with the new converter.
|
||||||
|
if not self.experimental_new_converter:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
def _get_base_converter_args(self):
|
def _get_base_converter_args(self):
|
||||||
"""Returns the base converter args.
|
"""Returns the base converter args.
|
||||||
|
|
||||||
@ -456,6 +466,7 @@ class TFLiteConverterV2(TFLiteConverterBase):
|
|||||||
config=self._grappler_config(),
|
config=self._grappler_config(),
|
||||||
graph=frozen_func.graph)
|
graph=frozen_func.graph)
|
||||||
|
|
||||||
|
if not self._is_unknown_shapes_allowed():
|
||||||
# Checks dimensions in input tensor.
|
# Checks dimensions in input tensor.
|
||||||
for tensor in input_tensors:
|
for tensor in input_tensors:
|
||||||
# Note that shape_list might be empty for scalar shapes.
|
# Note that shape_list might be empty for scalar shapes.
|
||||||
@ -463,7 +474,8 @@ class TFLiteConverterV2(TFLiteConverterBase):
|
|||||||
if None in shape_list[1:]:
|
if None in shape_list[1:]:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"None is only supported in the 1st dimension. Tensor '{0}' has "
|
"None is only supported in the 1st dimension. Tensor '{0}' has "
|
||||||
"invalid shape '{1}'.".format(_get_tensor_name(tensor), shape_list))
|
"invalid shape '{1}'.".format(
|
||||||
|
_get_tensor_name(tensor), shape_list))
|
||||||
elif shape_list and shape_list[0] is None:
|
elif shape_list and shape_list[0] is None:
|
||||||
# Set the batch size to 1 if undefined.
|
# Set the batch size to 1 if undefined.
|
||||||
shape = tensor.shape.as_list()
|
shape = tensor.shape.as_list()
|
||||||
@ -942,7 +954,7 @@ class TFLiteConverter(TFLiteConverterBase):
|
|||||||
None value for dimension in input_tensor.
|
None value for dimension in input_tensor.
|
||||||
"""
|
"""
|
||||||
# Checks dimensions in input tensor.
|
# Checks dimensions in input tensor.
|
||||||
if self._has_valid_tensors():
|
if not self._is_unknown_shapes_allowed() and self._has_valid_tensors():
|
||||||
for tensor in self._input_tensors:
|
for tensor in self._input_tensors:
|
||||||
shape = tensor.shape
|
shape = tensor.shape
|
||||||
if not shape:
|
if not shape:
|
||||||
@ -1115,6 +1127,20 @@ class TFLiteConverter(TFLiteConverterBase):
|
|||||||
shape[0] = batch_size
|
shape[0] = batch_size
|
||||||
tensor.set_shape(shape)
|
tensor.set_shape(shape)
|
||||||
|
|
||||||
|
def _is_unknown_shapes_allowed(self):
|
||||||
|
if not super(TFLiteConverter, self)._is_unknown_shapes_allowed():
|
||||||
|
return False
|
||||||
|
|
||||||
|
# `conversion_summary_dir` calls TOCO. Unknown shapes are only supported by
|
||||||
|
# the MLIR converter.
|
||||||
|
if self.conversion_summary_dir:
|
||||||
|
logging.warning(
|
||||||
|
"`conversion_summary_dir` does not work with unknown shapes. "
|
||||||
|
"Graphs with unknown shapes might be different than when this flag "
|
||||||
|
"is disabled.")
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
@_tf_export(v1=["lite.TocoConverter"])
|
@_tf_export(v1=["lite.TocoConverter"])
|
||||||
class TocoConverter(object):
|
class TocoConverter(object):
|
||||||
|
@ -318,9 +318,11 @@ class FromSessionTest(TestModels, parameterized.TestCase):
|
|||||||
out_tensor = in_tensor + in_tensor
|
out_tensor = in_tensor + in_tensor
|
||||||
sess = session.Session()
|
sess = session.Session()
|
||||||
|
|
||||||
# Test None as shape.
|
# Test None as shape when dynamic shapes are disabled. Run with TOCO in
|
||||||
|
# order to invoke shape checking code.
|
||||||
converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
|
converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
|
||||||
[out_tensor])
|
[out_tensor])
|
||||||
|
converter.experimental_new_converter = False
|
||||||
with self.assertRaises(ValueError) as error:
|
with self.assertRaises(ValueError) as error:
|
||||||
converter.convert()
|
converter.convert()
|
||||||
self.assertEqual('Provide an input shape for input array \'Placeholder\'.',
|
self.assertEqual('Provide an input shape for input array \'Placeholder\'.',
|
||||||
@ -375,9 +377,11 @@ class FromSessionTest(TestModels, parameterized.TestCase):
|
|||||||
out_tensor = in_tensor + in_tensor
|
out_tensor = in_tensor + in_tensor
|
||||||
sess = session.Session()
|
sess = session.Session()
|
||||||
|
|
||||||
# Test invalid shape. None after 1st dimension.
|
# Test invalid shape. None after 1st dimension. Run with TOCO in order to
|
||||||
|
# invoke shape checking code.
|
||||||
converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
|
converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
|
||||||
[out_tensor])
|
[out_tensor])
|
||||||
|
converter.experimental_new_converter = False
|
||||||
with self.assertRaises(ValueError) as error:
|
with self.assertRaises(ValueError) as error:
|
||||||
converter.convert()
|
converter.convert()
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
@ -385,6 +389,44 @@ class FromSessionTest(TestModels, parameterized.TestCase):
|
|||||||
'\'Placeholder\' has invalid shape \'[1, None, 16, 3]\'.',
|
'\'Placeholder\' has invalid shape \'[1, None, 16, 3]\'.',
|
||||||
str(error.exception))
|
str(error.exception))
|
||||||
|
|
||||||
|
def testSizeNone(self):
|
||||||
|
with ops.Graph().as_default():
|
||||||
|
in_tensor = array_ops.placeholder(
|
||||||
|
shape=[1, None, 16, 3], dtype=dtypes.float32)
|
||||||
|
out_tensor = in_tensor + in_tensor
|
||||||
|
sess = session.Session()
|
||||||
|
|
||||||
|
# Test None after 1st dimension.
|
||||||
|
converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
|
||||||
|
[out_tensor])
|
||||||
|
converter.experimental_new_converter = True
|
||||||
|
tflite_model = converter.convert()
|
||||||
|
|
||||||
|
# Check values from converted model.
|
||||||
|
interpreter = Interpreter(model_content=tflite_model)
|
||||||
|
input_details = interpreter.get_input_details()
|
||||||
|
self.assertLen(input_details, 1)
|
||||||
|
self.assertEqual('Placeholder', input_details[0]['name'])
|
||||||
|
self.assertEqual(np.float32, input_details[0]['dtype'])
|
||||||
|
self.assertTrue(([1, 1, 16, 3] == input_details[0]['shape']).all())
|
||||||
|
self.assertTrue(([1, -1, 16,
|
||||||
|
3] == input_details[0]['shape_signature']).all())
|
||||||
|
self.assertEqual((0., 0.), input_details[0]['quantization'])
|
||||||
|
|
||||||
|
# Resize tensor and invoke.
|
||||||
|
interpreter.resize_tensor_input(0, [1, 16, 16, 3])
|
||||||
|
interpreter.allocate_tensors()
|
||||||
|
interpreter.invoke()
|
||||||
|
|
||||||
|
input_details = interpreter.get_input_details()
|
||||||
|
self.assertLen(input_details, 1)
|
||||||
|
self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
|
||||||
|
self.assertTrue(([1, -1, 16,
|
||||||
|
3] == input_details[0]['shape_signature']).all())
|
||||||
|
|
||||||
|
output_details = interpreter.get_output_details()
|
||||||
|
self.assertFalse(output_details[0]['shape_signature'])
|
||||||
|
|
||||||
def testBatchSizeValid(self):
|
def testBatchSizeValid(self):
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
in_tensor = array_ops.placeholder(
|
in_tensor = array_ops.placeholder(
|
||||||
|
@ -54,12 +54,28 @@ from tensorflow.python.training.tracking import tracking
|
|||||||
|
|
||||||
class TestModels(test_util.TensorFlowTestCase, parameterized.TestCase):
|
class TestModels(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||||
|
|
||||||
def _evaluateTFLiteModel(self, tflite_model, input_data):
|
def _evaluateTFLiteModel(self, tflite_model, input_data, input_shapes=None):
|
||||||
"""Evaluates the model on the `input_data`."""
|
"""Evaluates the model on the `input_data`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tflite_model: TensorFlow Lite model.
|
||||||
|
input_data: List of EagerTensor const ops containing the input data for
|
||||||
|
each input tensor.
|
||||||
|
input_shapes: List of tuples representing the `shape_signature` and the
|
||||||
|
new shape of each input tensor that has unknown dimensions.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
[np.ndarray]
|
||||||
|
"""
|
||||||
interpreter = Interpreter(model_content=tflite_model)
|
interpreter = Interpreter(model_content=tflite_model)
|
||||||
|
input_details = interpreter.get_input_details()
|
||||||
|
if input_shapes:
|
||||||
|
for idx, (shape_signature, final_shape) in enumerate(input_shapes):
|
||||||
|
self.assertTrue(
|
||||||
|
(input_details[idx]['shape_signature'] == shape_signature).all())
|
||||||
|
interpreter.resize_tensor_input(idx, final_shape)
|
||||||
interpreter.allocate_tensors()
|
interpreter.allocate_tensors()
|
||||||
|
|
||||||
input_details = interpreter.get_input_details()
|
|
||||||
output_details = interpreter.get_output_details()
|
output_details = interpreter.get_output_details()
|
||||||
|
|
||||||
for input_tensor, tensor_data in zip(input_details, input_data):
|
for input_tensor, tensor_data in zip(input_details, input_data):
|
||||||
@ -795,5 +811,62 @@ class GrapplerTest(TestModels):
|
|||||||
actual_value = self._evaluateTFLiteModel(hybrid_tflite_model, [input_data])
|
actual_value = self._evaluateTFLiteModel(hybrid_tflite_model, [input_data])
|
||||||
np.testing.assert_almost_equal(expected_value.numpy(), actual_value[0])
|
np.testing.assert_almost_equal(expected_value.numpy(), actual_value[0])
|
||||||
|
|
||||||
|
|
||||||
|
class UnknownShapes(TestModels):
|
||||||
|
|
||||||
|
@test_util.run_v2_only
|
||||||
|
def testMatMul(self):
|
||||||
|
input_data = constant_op.constant(
|
||||||
|
np.array(np.random.random_sample((10, 4)), dtype=np.float32))
|
||||||
|
|
||||||
|
@def_function.function(input_signature=[
|
||||||
|
tensor_spec.TensorSpec(shape=[None, 4], dtype=dtypes.float32)
|
||||||
|
])
|
||||||
|
def model(in_tensor):
|
||||||
|
shape = array_ops.shape_v2(in_tensor)
|
||||||
|
fill = array_ops.transpose_v2(array_ops.fill(shape, 1.))
|
||||||
|
return math_ops.matmul(fill, in_tensor)
|
||||||
|
|
||||||
|
concrete_func = model.get_concrete_function()
|
||||||
|
|
||||||
|
converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func])
|
||||||
|
converter.experimental_new_converter = True
|
||||||
|
tflite_model = converter.convert()
|
||||||
|
|
||||||
|
# Check values from converted model.
|
||||||
|
expected_value = concrete_func(input_data)
|
||||||
|
actual_value = self._evaluateTFLiteModel(
|
||||||
|
tflite_model, [input_data], input_shapes=[([-1, 4], [10, 4])])
|
||||||
|
np.testing.assert_almost_equal(
|
||||||
|
expected_value.numpy(), actual_value[0], decimal=6)
|
||||||
|
|
||||||
|
def testBatchMatMul(self):
|
||||||
|
self.skipTest('BatchMatMulV2 ranked tensor check fails.')
|
||||||
|
input_data_1 = constant_op.constant(
|
||||||
|
np.array(np.random.random_sample((1, 256, 256)), dtype=np.float32))
|
||||||
|
input_data_2 = constant_op.constant(
|
||||||
|
np.array(np.random.random_sample((1, 2, 256)), dtype=np.float32))
|
||||||
|
|
||||||
|
@def_function.function(input_signature=[
|
||||||
|
tensor_spec.TensorSpec(shape=[1, 256, 256], dtype=dtypes.float32),
|
||||||
|
tensor_spec.TensorSpec(shape=[1, None, 256], dtype=dtypes.float32)
|
||||||
|
])
|
||||||
|
def model(in_tensor_1, in_tensor_2):
|
||||||
|
return math_ops.matmul(in_tensor_1, in_tensor_2)
|
||||||
|
|
||||||
|
concrete_func = model.get_concrete_function()
|
||||||
|
|
||||||
|
converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func])
|
||||||
|
converter.experimental_new_converter = True
|
||||||
|
tflite_model = converter.convert()
|
||||||
|
|
||||||
|
# Check values from converted model.
|
||||||
|
expected_value = concrete_func(input_data_1, input_data_2)
|
||||||
|
actual_value = self._evaluateTFLiteModel(
|
||||||
|
tflite_model, [input_data_1, input_data_2],
|
||||||
|
input_shapes={1: [1, 2, 256]})
|
||||||
|
np.testing.assert_almost_equal(expected_value.numpy(), actual_value[0])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test.main()
|
test.main()
|
||||||
|
@ -178,6 +178,10 @@ table Tensor {
|
|||||||
// Parameters to encode a sparse tensor. See the example in
|
// Parameters to encode a sparse tensor. See the example in
|
||||||
// tensorflow/lite/testdata/sparse_tensor.json.
|
// tensorflow/lite/testdata/sparse_tensor.json.
|
||||||
sparsity:SparsityParameters; // Optional.
|
sparsity:SparsityParameters; // Optional.
|
||||||
|
|
||||||
|
// Encodes `shape` with unknown dimensions. Unknown dimensions are
|
||||||
|
// represented with -1.
|
||||||
|
shape_signature:[int]; // Optional.
|
||||||
}
|
}
|
||||||
|
|
||||||
// A list of builtin operators. Builtin operators are slightly faster than custom
|
// A list of builtin operators. Builtin operators are slightly faster than custom
|
||||||
|
@ -3175,6 +3175,7 @@ struct TensorT : public flatbuffers::NativeTable {
|
|||||||
std::unique_ptr<QuantizationParametersT> quantization;
|
std::unique_ptr<QuantizationParametersT> quantization;
|
||||||
bool is_variable;
|
bool is_variable;
|
||||||
std::unique_ptr<SparsityParametersT> sparsity;
|
std::unique_ptr<SparsityParametersT> sparsity;
|
||||||
|
std::vector<int32_t> shape_signature;
|
||||||
TensorT()
|
TensorT()
|
||||||
: type(TensorType_FLOAT32),
|
: type(TensorType_FLOAT32),
|
||||||
buffer(0),
|
buffer(0),
|
||||||
@ -3191,7 +3192,8 @@ struct Tensor FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
|||||||
VT_NAME = 10,
|
VT_NAME = 10,
|
||||||
VT_QUANTIZATION = 12,
|
VT_QUANTIZATION = 12,
|
||||||
VT_IS_VARIABLE = 14,
|
VT_IS_VARIABLE = 14,
|
||||||
VT_SPARSITY = 16
|
VT_SPARSITY = 16,
|
||||||
|
VT_SHAPE_SIGNATURE = 18
|
||||||
};
|
};
|
||||||
const flatbuffers::Vector<int32_t> *shape() const {
|
const flatbuffers::Vector<int32_t> *shape() const {
|
||||||
return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_SHAPE);
|
return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_SHAPE);
|
||||||
@ -3214,6 +3216,9 @@ struct Tensor FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
|||||||
const SparsityParameters *sparsity() const {
|
const SparsityParameters *sparsity() const {
|
||||||
return GetPointer<const SparsityParameters *>(VT_SPARSITY);
|
return GetPointer<const SparsityParameters *>(VT_SPARSITY);
|
||||||
}
|
}
|
||||||
|
const flatbuffers::Vector<int32_t> *shape_signature() const {
|
||||||
|
return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_SHAPE_SIGNATURE);
|
||||||
|
}
|
||||||
bool Verify(flatbuffers::Verifier &verifier) const {
|
bool Verify(flatbuffers::Verifier &verifier) const {
|
||||||
return VerifyTableStart(verifier) &&
|
return VerifyTableStart(verifier) &&
|
||||||
VerifyOffset(verifier, VT_SHAPE) &&
|
VerifyOffset(verifier, VT_SHAPE) &&
|
||||||
@ -3227,6 +3232,8 @@ struct Tensor FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
|||||||
VerifyField<uint8_t>(verifier, VT_IS_VARIABLE) &&
|
VerifyField<uint8_t>(verifier, VT_IS_VARIABLE) &&
|
||||||
VerifyOffset(verifier, VT_SPARSITY) &&
|
VerifyOffset(verifier, VT_SPARSITY) &&
|
||||||
verifier.VerifyTable(sparsity()) &&
|
verifier.VerifyTable(sparsity()) &&
|
||||||
|
VerifyOffset(verifier, VT_SHAPE_SIGNATURE) &&
|
||||||
|
verifier.VerifyVector(shape_signature()) &&
|
||||||
verifier.EndTable();
|
verifier.EndTable();
|
||||||
}
|
}
|
||||||
TensorT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
|
TensorT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
|
||||||
@ -3258,6 +3265,9 @@ struct TensorBuilder {
|
|||||||
void add_sparsity(flatbuffers::Offset<SparsityParameters> sparsity) {
|
void add_sparsity(flatbuffers::Offset<SparsityParameters> sparsity) {
|
||||||
fbb_.AddOffset(Tensor::VT_SPARSITY, sparsity);
|
fbb_.AddOffset(Tensor::VT_SPARSITY, sparsity);
|
||||||
}
|
}
|
||||||
|
void add_shape_signature(flatbuffers::Offset<flatbuffers::Vector<int32_t>> shape_signature) {
|
||||||
|
fbb_.AddOffset(Tensor::VT_SHAPE_SIGNATURE, shape_signature);
|
||||||
|
}
|
||||||
explicit TensorBuilder(flatbuffers::FlatBufferBuilder &_fbb)
|
explicit TensorBuilder(flatbuffers::FlatBufferBuilder &_fbb)
|
||||||
: fbb_(_fbb) {
|
: fbb_(_fbb) {
|
||||||
start_ = fbb_.StartTable();
|
start_ = fbb_.StartTable();
|
||||||
@ -3278,8 +3288,10 @@ inline flatbuffers::Offset<Tensor> CreateTensor(
|
|||||||
flatbuffers::Offset<flatbuffers::String> name = 0,
|
flatbuffers::Offset<flatbuffers::String> name = 0,
|
||||||
flatbuffers::Offset<QuantizationParameters> quantization = 0,
|
flatbuffers::Offset<QuantizationParameters> quantization = 0,
|
||||||
bool is_variable = false,
|
bool is_variable = false,
|
||||||
flatbuffers::Offset<SparsityParameters> sparsity = 0) {
|
flatbuffers::Offset<SparsityParameters> sparsity = 0,
|
||||||
|
flatbuffers::Offset<flatbuffers::Vector<int32_t>> shape_signature = 0) {
|
||||||
TensorBuilder builder_(_fbb);
|
TensorBuilder builder_(_fbb);
|
||||||
|
builder_.add_shape_signature(shape_signature);
|
||||||
builder_.add_sparsity(sparsity);
|
builder_.add_sparsity(sparsity);
|
||||||
builder_.add_quantization(quantization);
|
builder_.add_quantization(quantization);
|
||||||
builder_.add_name(name);
|
builder_.add_name(name);
|
||||||
@ -3298,9 +3310,11 @@ inline flatbuffers::Offset<Tensor> CreateTensorDirect(
|
|||||||
const char *name = nullptr,
|
const char *name = nullptr,
|
||||||
flatbuffers::Offset<QuantizationParameters> quantization = 0,
|
flatbuffers::Offset<QuantizationParameters> quantization = 0,
|
||||||
bool is_variable = false,
|
bool is_variable = false,
|
||||||
flatbuffers::Offset<SparsityParameters> sparsity = 0) {
|
flatbuffers::Offset<SparsityParameters> sparsity = 0,
|
||||||
|
const std::vector<int32_t> *shape_signature = nullptr) {
|
||||||
auto shape__ = shape ? _fbb.CreateVector<int32_t>(*shape) : 0;
|
auto shape__ = shape ? _fbb.CreateVector<int32_t>(*shape) : 0;
|
||||||
auto name__ = name ? _fbb.CreateString(name) : 0;
|
auto name__ = name ? _fbb.CreateString(name) : 0;
|
||||||
|
auto shape_signature__ = shape_signature ? _fbb.CreateVector<int32_t>(*shape_signature) : 0;
|
||||||
return tflite::CreateTensor(
|
return tflite::CreateTensor(
|
||||||
_fbb,
|
_fbb,
|
||||||
shape__,
|
shape__,
|
||||||
@ -3309,7 +3323,8 @@ inline flatbuffers::Offset<Tensor> CreateTensorDirect(
|
|||||||
name__,
|
name__,
|
||||||
quantization,
|
quantization,
|
||||||
is_variable,
|
is_variable,
|
||||||
sparsity);
|
sparsity,
|
||||||
|
shape_signature__);
|
||||||
}
|
}
|
||||||
|
|
||||||
flatbuffers::Offset<Tensor> CreateTensor(flatbuffers::FlatBufferBuilder &_fbb, const TensorT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
|
flatbuffers::Offset<Tensor> CreateTensor(flatbuffers::FlatBufferBuilder &_fbb, const TensorT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
|
||||||
@ -10275,6 +10290,7 @@ inline void Tensor::UnPackTo(TensorT *_o, const flatbuffers::resolver_function_t
|
|||||||
{ auto _e = quantization(); if (_e) _o->quantization = std::unique_ptr<QuantizationParametersT>(_e->UnPack(_resolver)); };
|
{ auto _e = quantization(); if (_e) _o->quantization = std::unique_ptr<QuantizationParametersT>(_e->UnPack(_resolver)); };
|
||||||
{ auto _e = is_variable(); _o->is_variable = _e; };
|
{ auto _e = is_variable(); _o->is_variable = _e; };
|
||||||
{ auto _e = sparsity(); if (_e) _o->sparsity = std::unique_ptr<SparsityParametersT>(_e->UnPack(_resolver)); };
|
{ auto _e = sparsity(); if (_e) _o->sparsity = std::unique_ptr<SparsityParametersT>(_e->UnPack(_resolver)); };
|
||||||
|
{ auto _e = shape_signature(); if (_e) { _o->shape_signature.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->shape_signature[_i] = _e->Get(_i); } } };
|
||||||
}
|
}
|
||||||
|
|
||||||
inline flatbuffers::Offset<Tensor> Tensor::Pack(flatbuffers::FlatBufferBuilder &_fbb, const TensorT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
|
inline flatbuffers::Offset<Tensor> Tensor::Pack(flatbuffers::FlatBufferBuilder &_fbb, const TensorT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
|
||||||
@ -10292,6 +10308,7 @@ inline flatbuffers::Offset<Tensor> CreateTensor(flatbuffers::FlatBufferBuilder &
|
|||||||
auto _quantization = _o->quantization ? CreateQuantizationParameters(_fbb, _o->quantization.get(), _rehasher) : 0;
|
auto _quantization = _o->quantization ? CreateQuantizationParameters(_fbb, _o->quantization.get(), _rehasher) : 0;
|
||||||
auto _is_variable = _o->is_variable;
|
auto _is_variable = _o->is_variable;
|
||||||
auto _sparsity = _o->sparsity ? CreateSparsityParameters(_fbb, _o->sparsity.get(), _rehasher) : 0;
|
auto _sparsity = _o->sparsity ? CreateSparsityParameters(_fbb, _o->sparsity.get(), _rehasher) : 0;
|
||||||
|
auto _shape_signature = _o->shape_signature.size() ? _fbb.CreateVector(_o->shape_signature) : 0;
|
||||||
return tflite::CreateTensor(
|
return tflite::CreateTensor(
|
||||||
_fbb,
|
_fbb,
|
||||||
_shape,
|
_shape,
|
||||||
@ -10300,7 +10317,8 @@ inline flatbuffers::Offset<Tensor> CreateTensor(flatbuffers::FlatBufferBuilder &
|
|||||||
_name,
|
_name,
|
||||||
_quantization,
|
_quantization,
|
||||||
_is_variable,
|
_is_variable,
|
||||||
_sparsity);
|
_sparsity,
|
||||||
|
_shape_signature);
|
||||||
}
|
}
|
||||||
|
|
||||||
inline Conv2DOptionsT *Conv2DOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
|
inline Conv2DOptionsT *Conv2DOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
|
||||||
|
@ -391,6 +391,12 @@ typedef struct TfLiteTensor {
|
|||||||
// This is optional. The field is NULL if a tensor is dense.
|
// This is optional. The field is NULL if a tensor is dense.
|
||||||
// WARNING: This is an experimental interface that is subject to change.
|
// WARNING: This is an experimental interface that is subject to change.
|
||||||
TfLiteSparsity* sparsity;
|
TfLiteSparsity* sparsity;
|
||||||
|
|
||||||
|
// Optional. Encodes shapes with unknown dimensions with -1. This field is
|
||||||
|
// only populated when unknown dimensions exist in a read-write tensor (i.e.
|
||||||
|
// an input or output tensor). (e.g. `dims` contains [1, 1, 1, 3] and
|
||||||
|
// `dims_signature` contains [1, -1, -1, 3]).
|
||||||
|
const TfLiteIntArray* dims_signature;
|
||||||
} TfLiteTensor;
|
} TfLiteTensor;
|
||||||
|
|
||||||
#ifndef TF_LITE_STATIC_MEMORY
|
#ifndef TF_LITE_STATIC_MEMORY
|
||||||
|
Loading…
Reference in New Issue
Block a user