From 55912083e2f16087c2f29394acf8a6a4811a2ce0 Mon Sep 17 00:00:00 2001 From: Nupur Garg <nupurgarg@google.com> Date: Fri, 31 Jan 2020 09:53:26 -0800 Subject: [PATCH] Add support for unknown dimensions to TFLite using MLIR converter. PiperOrigin-RevId: 292563455 Change-Id: Ib5700cfe6faee177027329e32089abb3bcc9adaf --- .../mlir/lite/flatbuffer_translate.cc | 28 ++++++- tensorflow/lite/c/common.c | 5 ++ tensorflow/lite/c/common.h | 6 ++ tensorflow/lite/c/common_test.cc | 2 + tensorflow/lite/core/subgraph.cc | 5 +- tensorflow/lite/core/subgraph.h | 16 ++-- tensorflow/lite/model.cc | 13 ++- tensorflow/lite/python/convert.py | 12 ++- tensorflow/lite/python/interpreter.py | 2 + .../interpreter_wrapper.cc | 17 ++++ .../interpreter_wrapper/interpreter_wrapper.h | 1 + tensorflow/lite/python/lite.py | 54 +++++++++---- tensorflow/lite/python/lite_test.py | 46 ++++++++++- tensorflow/lite/python/lite_v2_test.py | 79 ++++++++++++++++++- tensorflow/lite/schema/schema.fbs | 4 + tensorflow/lite/schema/schema_generated.h | 28 +++++-- .../benchmark/experimental/c/c_api_types.h | 6 ++ 17 files changed, 284 insertions(+), 40 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc b/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc index 14e99ce76f8..7b909c0c857 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc @@ -610,6 +610,7 @@ Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor( }; std::vector<int32_t> shape; + std::vector<int32_t> shape_signature; if (type.hasStaticShape()) { llvm::ArrayRef<int64_t> shape_ref = type.getShape(); if (mlir::failed(check_shape(shape_ref))) return llvm::None; @@ -627,7 +628,17 @@ Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor( shape = std::vector<int32_t>(shape_ref.begin(), shape_ref.end()); } + } else if (type.hasRank()) { + llvm::ArrayRef<int64_t> shape_ref = type.getShape(); + if (mlir::failed(check_shape(shape_ref))) return llvm::None; + + shape.reserve(shape_ref.size()); + for (auto& dim : shape_ref) { + shape.push_back(dim == -1 ? 1 : dim); + } + shape_signature = std::vector<int32_t>(shape_ref.begin(), shape_ref.end()); } + Type element_type = type.getElementType(); tflite::TensorType tflite_element_type = GetTFLiteType(type.getElementType()).ValueOrDie(); @@ -664,10 +675,19 @@ Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor( break; } } - return tflite::CreateTensor( - builder_, builder_.CreateVector(shape), tflite_element_type, - (is_variable ? 0 : buffer_idx), builder_.CreateString(name), q_params, - /*is_variable=*/is_variable); + + if (shape_signature.empty()) { + return tflite::CreateTensor( + builder_, builder_.CreateVector(shape), tflite_element_type, + (is_variable ? 0 : buffer_idx), builder_.CreateString(name), q_params, + /*is_variable=*/is_variable); + } else { + return tflite::CreateTensor( + builder_, builder_.CreateVector(shape), tflite_element_type, + (is_variable ? 0 : buffer_idx), builder_.CreateString(name), q_params, + /*is_variable=*/is_variable, /*sparsity=*/0, + /*shape_signature=*/builder_.CreateVector(shape_signature)); + } } BufferOffset<tflite::Operator> Translator::BuildIfOperator( diff --git a/tensorflow/lite/c/common.c b/tensorflow/lite/c/common.c index 1721e75d7ce..7196f32b62a 100644 --- a/tensorflow/lite/c/common.c +++ b/tensorflow/lite/c/common.c @@ -140,6 +140,11 @@ void TfLiteTensorFree(TfLiteTensor* t) { if (t->dims) TfLiteIntArrayFree(t->dims); t->dims = NULL; + if (t->dims_signature) { + TfLiteIntArrayFree((TfLiteIntArray *) t->dims_signature); + } + t->dims_signature = NULL; + TfLiteQuantizationFree(&t->quantization); TfLiteSparsityFree(t->sparsity); t->sparsity = NULL; diff --git a/tensorflow/lite/c/common.h b/tensorflow/lite/c/common.h index 4d7fe8c78a8..023e1871d2b 100644 --- a/tensorflow/lite/c/common.h +++ b/tensorflow/lite/c/common.h @@ -391,6 +391,12 @@ typedef struct TfLiteTensor { // This is optional. The field is NULL if a tensor is dense. // WARNING: This is an experimental interface that is subject to change. TfLiteSparsity* sparsity; + + // Optional. Encodes shapes with unknown dimensions with -1. This field is + // only populated when unknown dimensions exist in a read-write tensor (i.e. + // an input or output tensor). (e.g. `dims` contains [1, 1, 1, 3] and + // `dims_signature` contains [1, -1, -1, 3]). + const TfLiteIntArray* dims_signature; } TfLiteTensor; #ifndef TF_LITE_STATIC_MEMORY diff --git a/tensorflow/lite/c/common_test.cc b/tensorflow/lite/c/common_test.cc index 7230adff0e9..0421b50c05e 100644 --- a/tensorflow/lite/c/common_test.cc +++ b/tensorflow/lite/c/common_test.cc @@ -95,6 +95,7 @@ TEST(Quantization, TestQuantizationFree) { // Set these values, otherwise TfLiteTensorFree has uninitialized values. t.allocation_type = kTfLiteArenaRw; t.dims = nullptr; + t.dims_signature = nullptr; t.quantization.type = kTfLiteAffineQuantization; t.sparsity = nullptr; auto* params = reinterpret_cast<TfLiteAffineQuantization*>( @@ -110,6 +111,7 @@ TEST(Sparsity, TestSparsityFree) { // Set these values, otherwise TfLiteTensorFree has uninitialized values. t.allocation_type = kTfLiteArenaRw; t.dims = nullptr; + t.dims_signature = nullptr; // A dummy CSR sparse matrix. t.sparsity = static_cast<TfLiteSparsity*>(malloc(sizeof(TfLiteSparsity))); diff --git a/tensorflow/lite/core/subgraph.cc b/tensorflow/lite/core/subgraph.cc index 0f94ca0ae3c..e49b9ad9a59 100644 --- a/tensorflow/lite/core/subgraph.cc +++ b/tensorflow/lite/core/subgraph.cc @@ -1074,7 +1074,8 @@ TfLiteStatus Subgraph::SetTensorParametersReadOnly( // to Interpreter. TfLiteStatus Subgraph::SetTensorParametersReadWrite( int tensor_index, TfLiteType type, const char* name, const size_t rank, - const int* dims, TfLiteQuantization quantization, bool is_variable) { + const int* dims, TfLiteQuantization quantization, bool is_variable, + const size_t rank_dims_signature, const int* dims_signature) { // Ensure quantization cleanup on failure. ScopedTfLiteQuantization scoped_quantization(&quantization); if (state_ == kStateInvokableAndImmutable) { @@ -1114,6 +1115,8 @@ TfLiteStatus Subgraph::SetTensorParametersReadWrite( // TODO(suharshs): Update TfLiteTensorReset to include the new quantization // if there are other required callers. tensor.quantization = *scoped_quantization.release(); + tensor.dims_signature = + ConvertArrayToTfLiteIntArray(rank_dims_signature, dims_signature); return kTfLiteOk; } diff --git a/tensorflow/lite/core/subgraph.h b/tensorflow/lite/core/subgraph.h index 58c125a5f98..021439e827b 100644 --- a/tensorflow/lite/core/subgraph.h +++ b/tensorflow/lite/core/subgraph.h @@ -114,15 +114,17 @@ class Subgraph { inline TfLiteStatus SetTensorParametersReadWrite( int tensor_index, TfLiteType type, const char* name, const std::vector<int>& dims, TfLiteQuantization quantization, - bool is_variable = false) { + bool is_variable = false, const size_t rank_dims_signature = 0, + const int* dims_signature = nullptr) { return SetTensorParametersReadWrite(tensor_index, type, name, dims.size(), - dims.data(), quantization, is_variable); + dims.data(), quantization, is_variable, + rank_dims_signature, dims_signature); } - TfLiteStatus SetTensorParametersReadWrite(int tensor_index, TfLiteType type, - const char* name, const size_t rank, - const int* dims, - TfLiteQuantization quantization, - bool is_variable = false); + TfLiteStatus SetTensorParametersReadWrite( + int tensor_index, TfLiteType type, const char* name, const size_t rank, + const int* dims, TfLiteQuantization quantization, + bool is_variable = false, const size_t rank_dims_signature = 0, + const int* dims_signature = nullptr); // WARNING: Experimental interface, subject to change // Overrides execution plan. This bounds checks indices sent in. diff --git a/tensorflow/lite/model.cc b/tensorflow/lite/model.cc index 0556f47adba..04d064d0933 100644 --- a/tensorflow/lite/model.cc +++ b/tensorflow/lite/model.cc @@ -563,6 +563,13 @@ TfLiteStatus InterpreterBuilder::ParseTensors( status = kTfLiteError; } + size_t dims_signature_rank = 0; + const int* dims_signature_data = nullptr; + if (tensor->shape_signature()) { + dims_signature_rank = tensor->shape_signature()->Length(); + dims_signature_data = tensor->shape_signature()->data(); + } + bool is_variable = tensor->is_variable(); if (buffer_ptr) { if (is_variable) { @@ -590,9 +597,9 @@ TfLiteStatus InterpreterBuilder::ParseTensors( status = kTfLiteError; } } else { - if (subgraph->SetTensorParametersReadWrite(i, type, get_name(tensor), - dims, quantization, - is_variable) != kTfLiteOk) { + if (subgraph->SetTensorParametersReadWrite( + i, type, get_name(tensor), dims, quantization, is_variable, + dims_signature_rank, dims_signature_data) != kTfLiteOk) { error_reporter_->Report("Tensor %d is invalidly specified in schema.\n", i); status = kTfLiteError; diff --git a/tensorflow/lite/python/convert.py b/tensorflow/lite/python/convert.py index 2fe4d172487..4813edef126 100644 --- a/tensorflow/lite/python/convert.py +++ b/tensorflow/lite/python/convert.py @@ -35,6 +35,7 @@ from tensorflow.lite.python import wrap_toco from tensorflow.lite.toco import model_flags_pb2 as _model_flags_pb2 from tensorflow.lite.toco import toco_flags_pb2 as _toco_flags_pb2 from tensorflow.lite.toco import types_pb2 as _types_pb2 +from tensorflow.python.framework import tensor_shape from tensorflow.python.platform import resource_loader as _resource_loader from tensorflow.python.util import deprecation from tensorflow.python.util.tf_export import tf_export as _tf_export @@ -384,7 +385,16 @@ def build_toco_convert_protos(input_tensors, shape = input_tensor.shape else: shape = input_shapes[idx] - input_array.shape.dims.extend(list(map(int, shape))) + + # Create shapes with -1 for unknown dimensions. + dims = [] + for dim in shape: + if (dim is None or + (isinstance(dim, tensor_shape.Dimension) and dim.value is None)): + dims.append(-1) + else: + dims.append(int(dim)) + input_array.shape.dims.extend(dims) for output_tensor in output_tensors: model.output_arrays.append(util.get_tensor_name(output_tensor)) diff --git a/tensorflow/lite/python/interpreter.py b/tensorflow/lite/python/interpreter.py index 153b6f17c3c..4acedabeab9 100644 --- a/tensorflow/lite/python/interpreter.py +++ b/tensorflow/lite/python/interpreter.py @@ -320,6 +320,7 @@ class Interpreter(object): tensor_index = int(tensor_index) tensor_name = self._interpreter.TensorName(tensor_index) tensor_size = self._interpreter.TensorSize(tensor_index) + tensor_size_signature = self._interpreter.TensorSizeSignature(tensor_index) tensor_type = self._interpreter.TensorType(tensor_index) tensor_quantization = self._interpreter.TensorQuantization(tensor_index) tensor_quantization_params = self._interpreter.TensorQuantizationParameters( @@ -332,6 +333,7 @@ class Interpreter(object): 'name': tensor_name, 'index': tensor_index, 'shape': tensor_size, + 'shape_signature': tensor_size_signature, 'dtype': tensor_type, 'quantization': tensor_quantization, 'quantization_parameters': { diff --git a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc index 6b1bf34ea7d..58fb17e4f9b 100644 --- a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc +++ b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc @@ -301,6 +301,23 @@ PyObject* InterpreterWrapper::TensorSize(int i) const { return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array)); } +PyObject* InterpreterWrapper::TensorSizeSignature(int i) const { + TFLITE_PY_ENSURE_VALID_INTERPRETER(); + TFLITE_PY_TENSOR_BOUNDS_CHECK(i); + + const TfLiteTensor* tensor = interpreter_->tensor(i); + const int32_t* size_signature_data = nullptr; + int32_t size_signature_size = 0; + if (tensor->dims_signature != nullptr) { + size_signature_data = tensor->dims_signature->data; + size_signature_size = tensor->dims_signature->size; + } + PyObject* np_array = + PyArrayFromIntVector(size_signature_data, size_signature_size); + + return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array)); +} + PyObject* InterpreterWrapper::TensorQuantization(int i) const { TFLITE_PY_ENSURE_VALID_INTERPRETER(); TFLITE_PY_TENSOR_BOUNDS_CHECK(i); diff --git a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h index be9086f307b..c37d3e998cd 100644 --- a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h +++ b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h @@ -69,6 +69,7 @@ class InterpreterWrapper { std::string TensorName(int i) const; PyObject* TensorType(int i) const; PyObject* TensorSize(int i) const; + PyObject* TensorSizeSignature(int i) const; // Deprecated in favor of TensorQuantizationScales, below. PyObject* TensorQuantization(int i) const; PyObject* TensorQuantizationParameters(int i) const; diff --git a/tensorflow/lite/python/lite.py b/tensorflow/lite/python/lite.py index 657cfea1bb8..3965a4ac275 100644 --- a/tensorflow/lite/python/lite.py +++ b/tensorflow/lite/python/lite.py @@ -261,6 +261,16 @@ class TFLiteConverterBase(object): self.representative_dataset.input_gen, inference_input_type, inference_output_type, allow_float, enable_mlir_quantizer) + def _is_unknown_shapes_allowed(self): + # TODO(b/128319310): Investigate which quantization methods work. + if self._any_optimization_enabled(): + return False + + # Unknown dimensions are only allowed with the new converter. + if not self.experimental_new_converter: + return False + return True + def _get_base_converter_args(self): """Returns the base converter args. @@ -456,19 +466,21 @@ class TFLiteConverterV2(TFLiteConverterBase): config=self._grappler_config(), graph=frozen_func.graph) - # Checks dimensions in input tensor. - for tensor in input_tensors: - # Note that shape_list might be empty for scalar shapes. - shape_list = tensor.shape.as_list() - if None in shape_list[1:]: - raise ValueError( - "None is only supported in the 1st dimension. Tensor '{0}' has " - "invalid shape '{1}'.".format(_get_tensor_name(tensor), shape_list)) - elif shape_list and shape_list[0] is None: - # Set the batch size to 1 if undefined. - shape = tensor.shape.as_list() - shape[0] = 1 - tensor.set_shape(shape) + if not self._is_unknown_shapes_allowed(): + # Checks dimensions in input tensor. + for tensor in input_tensors: + # Note that shape_list might be empty for scalar shapes. + shape_list = tensor.shape.as_list() + if None in shape_list[1:]: + raise ValueError( + "None is only supported in the 1st dimension. Tensor '{0}' has " + "invalid shape '{1}'.".format( + _get_tensor_name(tensor), shape_list)) + elif shape_list and shape_list[0] is None: + # Set the batch size to 1 if undefined. + shape = tensor.shape.as_list() + shape[0] = 1 + tensor.set_shape(shape) self._validate_quantization() self._validate_representative_dataset() @@ -942,7 +954,7 @@ class TFLiteConverter(TFLiteConverterBase): None value for dimension in input_tensor. """ # Checks dimensions in input tensor. - if self._has_valid_tensors(): + if not self._is_unknown_shapes_allowed() and self._has_valid_tensors(): for tensor in self._input_tensors: shape = tensor.shape if not shape: @@ -1115,6 +1127,20 @@ class TFLiteConverter(TFLiteConverterBase): shape[0] = batch_size tensor.set_shape(shape) + def _is_unknown_shapes_allowed(self): + if not super(TFLiteConverter, self)._is_unknown_shapes_allowed(): + return False + + # `conversion_summary_dir` calls TOCO. Unknown shapes are only supported by + # the MLIR converter. + if self.conversion_summary_dir: + logging.warning( + "`conversion_summary_dir` does not work with unknown shapes. " + "Graphs with unknown shapes might be different than when this flag " + "is disabled.") + return False + return True + @_tf_export(v1=["lite.TocoConverter"]) class TocoConverter(object): diff --git a/tensorflow/lite/python/lite_test.py b/tensorflow/lite/python/lite_test.py index f3a29fb97a7..8c1f10af530 100644 --- a/tensorflow/lite/python/lite_test.py +++ b/tensorflow/lite/python/lite_test.py @@ -318,9 +318,11 @@ class FromSessionTest(TestModels, parameterized.TestCase): out_tensor = in_tensor + in_tensor sess = session.Session() - # Test None as shape. + # Test None as shape when dynamic shapes are disabled. Run with TOCO in + # order to invoke shape checking code. converter = lite.TFLiteConverter.from_session(sess, [in_tensor], [out_tensor]) + converter.experimental_new_converter = False with self.assertRaises(ValueError) as error: converter.convert() self.assertEqual('Provide an input shape for input array \'Placeholder\'.', @@ -375,9 +377,11 @@ class FromSessionTest(TestModels, parameterized.TestCase): out_tensor = in_tensor + in_tensor sess = session.Session() - # Test invalid shape. None after 1st dimension. + # Test invalid shape. None after 1st dimension. Run with TOCO in order to + # invoke shape checking code. converter = lite.TFLiteConverter.from_session(sess, [in_tensor], [out_tensor]) + converter.experimental_new_converter = False with self.assertRaises(ValueError) as error: converter.convert() self.assertEqual( @@ -385,6 +389,44 @@ class FromSessionTest(TestModels, parameterized.TestCase): '\'Placeholder\' has invalid shape \'[1, None, 16, 3]\'.', str(error.exception)) + def testSizeNone(self): + with ops.Graph().as_default(): + in_tensor = array_ops.placeholder( + shape=[1, None, 16, 3], dtype=dtypes.float32) + out_tensor = in_tensor + in_tensor + sess = session.Session() + + # Test None after 1st dimension. + converter = lite.TFLiteConverter.from_session(sess, [in_tensor], + [out_tensor]) + converter.experimental_new_converter = True + tflite_model = converter.convert() + + # Check values from converted model. + interpreter = Interpreter(model_content=tflite_model) + input_details = interpreter.get_input_details() + self.assertLen(input_details, 1) + self.assertEqual('Placeholder', input_details[0]['name']) + self.assertEqual(np.float32, input_details[0]['dtype']) + self.assertTrue(([1, 1, 16, 3] == input_details[0]['shape']).all()) + self.assertTrue(([1, -1, 16, + 3] == input_details[0]['shape_signature']).all()) + self.assertEqual((0., 0.), input_details[0]['quantization']) + + # Resize tensor and invoke. + interpreter.resize_tensor_input(0, [1, 16, 16, 3]) + interpreter.allocate_tensors() + interpreter.invoke() + + input_details = interpreter.get_input_details() + self.assertLen(input_details, 1) + self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) + self.assertTrue(([1, -1, 16, + 3] == input_details[0]['shape_signature']).all()) + + output_details = interpreter.get_output_details() + self.assertFalse(output_details[0]['shape_signature']) + def testBatchSizeValid(self): with ops.Graph().as_default(): in_tensor = array_ops.placeholder( diff --git a/tensorflow/lite/python/lite_v2_test.py b/tensorflow/lite/python/lite_v2_test.py index 1f0156d6524..bb149399ef9 100644 --- a/tensorflow/lite/python/lite_v2_test.py +++ b/tensorflow/lite/python/lite_v2_test.py @@ -54,12 +54,28 @@ from tensorflow.python.training.tracking import tracking class TestModels(test_util.TensorFlowTestCase, parameterized.TestCase): - def _evaluateTFLiteModel(self, tflite_model, input_data): - """Evaluates the model on the `input_data`.""" + def _evaluateTFLiteModel(self, tflite_model, input_data, input_shapes=None): + """Evaluates the model on the `input_data`. + + Args: + tflite_model: TensorFlow Lite model. + input_data: List of EagerTensor const ops containing the input data for + each input tensor. + input_shapes: List of tuples representing the `shape_signature` and the + new shape of each input tensor that has unknown dimensions. + + Returns: + [np.ndarray] + """ interpreter = Interpreter(model_content=tflite_model) + input_details = interpreter.get_input_details() + if input_shapes: + for idx, (shape_signature, final_shape) in enumerate(input_shapes): + self.assertTrue( + (input_details[idx]['shape_signature'] == shape_signature).all()) + interpreter.resize_tensor_input(idx, final_shape) interpreter.allocate_tensors() - input_details = interpreter.get_input_details() output_details = interpreter.get_output_details() for input_tensor, tensor_data in zip(input_details, input_data): @@ -795,5 +811,62 @@ class GrapplerTest(TestModels): actual_value = self._evaluateTFLiteModel(hybrid_tflite_model, [input_data]) np.testing.assert_almost_equal(expected_value.numpy(), actual_value[0]) + +class UnknownShapes(TestModels): + + @test_util.run_v2_only + def testMatMul(self): + input_data = constant_op.constant( + np.array(np.random.random_sample((10, 4)), dtype=np.float32)) + + @def_function.function(input_signature=[ + tensor_spec.TensorSpec(shape=[None, 4], dtype=dtypes.float32) + ]) + def model(in_tensor): + shape = array_ops.shape_v2(in_tensor) + fill = array_ops.transpose_v2(array_ops.fill(shape, 1.)) + return math_ops.matmul(fill, in_tensor) + + concrete_func = model.get_concrete_function() + + converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func]) + converter.experimental_new_converter = True + tflite_model = converter.convert() + + # Check values from converted model. + expected_value = concrete_func(input_data) + actual_value = self._evaluateTFLiteModel( + tflite_model, [input_data], input_shapes=[([-1, 4], [10, 4])]) + np.testing.assert_almost_equal( + expected_value.numpy(), actual_value[0], decimal=6) + + def testBatchMatMul(self): + self.skipTest('BatchMatMulV2 ranked tensor check fails.') + input_data_1 = constant_op.constant( + np.array(np.random.random_sample((1, 256, 256)), dtype=np.float32)) + input_data_2 = constant_op.constant( + np.array(np.random.random_sample((1, 2, 256)), dtype=np.float32)) + + @def_function.function(input_signature=[ + tensor_spec.TensorSpec(shape=[1, 256, 256], dtype=dtypes.float32), + tensor_spec.TensorSpec(shape=[1, None, 256], dtype=dtypes.float32) + ]) + def model(in_tensor_1, in_tensor_2): + return math_ops.matmul(in_tensor_1, in_tensor_2) + + concrete_func = model.get_concrete_function() + + converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func]) + converter.experimental_new_converter = True + tflite_model = converter.convert() + + # Check values from converted model. + expected_value = concrete_func(input_data_1, input_data_2) + actual_value = self._evaluateTFLiteModel( + tflite_model, [input_data_1, input_data_2], + input_shapes={1: [1, 2, 256]}) + np.testing.assert_almost_equal(expected_value.numpy(), actual_value[0]) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/lite/schema/schema.fbs b/tensorflow/lite/schema/schema.fbs index 0f052b5e5b3..e7d5eaed29f 100644 --- a/tensorflow/lite/schema/schema.fbs +++ b/tensorflow/lite/schema/schema.fbs @@ -178,6 +178,10 @@ table Tensor { // Parameters to encode a sparse tensor. See the example in // tensorflow/lite/testdata/sparse_tensor.json. sparsity:SparsityParameters; // Optional. + + // Encodes `shape` with unknown dimensions. Unknown dimensions are + // represented with -1. + shape_signature:[int]; // Optional. } // A list of builtin operators. Builtin operators are slightly faster than custom diff --git a/tensorflow/lite/schema/schema_generated.h b/tensorflow/lite/schema/schema_generated.h index 5daa0782b3a..b91a2f0343d 100755 --- a/tensorflow/lite/schema/schema_generated.h +++ b/tensorflow/lite/schema/schema_generated.h @@ -3175,6 +3175,7 @@ struct TensorT : public flatbuffers::NativeTable { std::unique_ptr<QuantizationParametersT> quantization; bool is_variable; std::unique_ptr<SparsityParametersT> sparsity; + std::vector<int32_t> shape_signature; TensorT() : type(TensorType_FLOAT32), buffer(0), @@ -3191,7 +3192,8 @@ struct Tensor FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VT_NAME = 10, VT_QUANTIZATION = 12, VT_IS_VARIABLE = 14, - VT_SPARSITY = 16 + VT_SPARSITY = 16, + VT_SHAPE_SIGNATURE = 18 }; const flatbuffers::Vector<int32_t> *shape() const { return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_SHAPE); @@ -3214,6 +3216,9 @@ struct Tensor FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const SparsityParameters *sparsity() const { return GetPointer<const SparsityParameters *>(VT_SPARSITY); } + const flatbuffers::Vector<int32_t> *shape_signature() const { + return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_SHAPE_SIGNATURE); + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_SHAPE) && @@ -3227,6 +3232,8 @@ struct Tensor FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VerifyField<uint8_t>(verifier, VT_IS_VARIABLE) && VerifyOffset(verifier, VT_SPARSITY) && verifier.VerifyTable(sparsity()) && + VerifyOffset(verifier, VT_SHAPE_SIGNATURE) && + verifier.VerifyVector(shape_signature()) && verifier.EndTable(); } TensorT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; @@ -3258,6 +3265,9 @@ struct TensorBuilder { void add_sparsity(flatbuffers::Offset<SparsityParameters> sparsity) { fbb_.AddOffset(Tensor::VT_SPARSITY, sparsity); } + void add_shape_signature(flatbuffers::Offset<flatbuffers::Vector<int32_t>> shape_signature) { + fbb_.AddOffset(Tensor::VT_SHAPE_SIGNATURE, shape_signature); + } explicit TensorBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -3278,8 +3288,10 @@ inline flatbuffers::Offset<Tensor> CreateTensor( flatbuffers::Offset<flatbuffers::String> name = 0, flatbuffers::Offset<QuantizationParameters> quantization = 0, bool is_variable = false, - flatbuffers::Offset<SparsityParameters> sparsity = 0) { + flatbuffers::Offset<SparsityParameters> sparsity = 0, + flatbuffers::Offset<flatbuffers::Vector<int32_t>> shape_signature = 0) { TensorBuilder builder_(_fbb); + builder_.add_shape_signature(shape_signature); builder_.add_sparsity(sparsity); builder_.add_quantization(quantization); builder_.add_name(name); @@ -3298,9 +3310,11 @@ inline flatbuffers::Offset<Tensor> CreateTensorDirect( const char *name = nullptr, flatbuffers::Offset<QuantizationParameters> quantization = 0, bool is_variable = false, - flatbuffers::Offset<SparsityParameters> sparsity = 0) { + flatbuffers::Offset<SparsityParameters> sparsity = 0, + const std::vector<int32_t> *shape_signature = nullptr) { auto shape__ = shape ? _fbb.CreateVector<int32_t>(*shape) : 0; auto name__ = name ? _fbb.CreateString(name) : 0; + auto shape_signature__ = shape_signature ? _fbb.CreateVector<int32_t>(*shape_signature) : 0; return tflite::CreateTensor( _fbb, shape__, @@ -3309,7 +3323,8 @@ inline flatbuffers::Offset<Tensor> CreateTensorDirect( name__, quantization, is_variable, - sparsity); + sparsity, + shape_signature__); } flatbuffers::Offset<Tensor> CreateTensor(flatbuffers::FlatBufferBuilder &_fbb, const TensorT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); @@ -10275,6 +10290,7 @@ inline void Tensor::UnPackTo(TensorT *_o, const flatbuffers::resolver_function_t { auto _e = quantization(); if (_e) _o->quantization = std::unique_ptr<QuantizationParametersT>(_e->UnPack(_resolver)); }; { auto _e = is_variable(); _o->is_variable = _e; }; { auto _e = sparsity(); if (_e) _o->sparsity = std::unique_ptr<SparsityParametersT>(_e->UnPack(_resolver)); }; + { auto _e = shape_signature(); if (_e) { _o->shape_signature.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->shape_signature[_i] = _e->Get(_i); } } }; } inline flatbuffers::Offset<Tensor> Tensor::Pack(flatbuffers::FlatBufferBuilder &_fbb, const TensorT* _o, const flatbuffers::rehasher_function_t *_rehasher) { @@ -10292,6 +10308,7 @@ inline flatbuffers::Offset<Tensor> CreateTensor(flatbuffers::FlatBufferBuilder & auto _quantization = _o->quantization ? CreateQuantizationParameters(_fbb, _o->quantization.get(), _rehasher) : 0; auto _is_variable = _o->is_variable; auto _sparsity = _o->sparsity ? CreateSparsityParameters(_fbb, _o->sparsity.get(), _rehasher) : 0; + auto _shape_signature = _o->shape_signature.size() ? _fbb.CreateVector(_o->shape_signature) : 0; return tflite::CreateTensor( _fbb, _shape, @@ -10300,7 +10317,8 @@ inline flatbuffers::Offset<Tensor> CreateTensor(flatbuffers::FlatBufferBuilder & _name, _quantization, _is_variable, - _sparsity); + _sparsity, + _shape_signature); } inline Conv2DOptionsT *Conv2DOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { diff --git a/tensorflow/lite/tools/benchmark/experimental/c/c_api_types.h b/tensorflow/lite/tools/benchmark/experimental/c/c_api_types.h index 4d7fe8c78a8..023e1871d2b 100644 --- a/tensorflow/lite/tools/benchmark/experimental/c/c_api_types.h +++ b/tensorflow/lite/tools/benchmark/experimental/c/c_api_types.h @@ -391,6 +391,12 @@ typedef struct TfLiteTensor { // This is optional. The field is NULL if a tensor is dense. // WARNING: This is an experimental interface that is subject to change. TfLiteSparsity* sparsity; + + // Optional. Encodes shapes with unknown dimensions with -1. This field is + // only populated when unknown dimensions exist in a read-write tensor (i.e. + // an input or output tensor). (e.g. `dims` contains [1, 1, 1, 3] and + // `dims_signature` contains [1, -1, -1, 3]). + const TfLiteIntArray* dims_signature; } TfLiteTensor; #ifndef TF_LITE_STATIC_MEMORY