From 89072571ca67f813d60c6ab1f6707eb4ace6be0d Mon Sep 17 00:00:00 2001 From: Nupur Garg Date: Mon, 4 May 2020 21:00:52 -0700 Subject: [PATCH] Add `strict` argument to `resize_tensor_input` in TFLite Python API. PiperOrigin-RevId: 309875051 Change-Id: Ia246d6a8b6c087ddcad86b3ead76cd2af833ec59 --- tensorflow/lite/python/interpreter.py | 14 ++++++- .../interpreter_wrapper.cc | 21 ++++++++++- .../interpreter_wrapper/interpreter_wrapper.h | 5 ++- .../interpreter_wrapper_pybind11.cc | 4 +- tensorflow/lite/python/lite_test.py | 37 ++++++++++++++++++- tensorflow/lite/python/lite_v2_test_util.py | 2 +- .../v1/tensorflow.lite.-interpreter.pbtxt | 2 +- .../v2/tensorflow.lite.-interpreter.pbtxt | 2 +- 8 files changed, 76 insertions(+), 11 deletions(-) diff --git a/tensorflow/lite/python/interpreter.py b/tensorflow/lite/python/interpreter.py index 39f303b3a68..ccbba9014c8 100644 --- a/tensorflow/lite/python/interpreter.py +++ b/tensorflow/lite/python/interpreter.py @@ -406,13 +406,23 @@ class Interpreter(object): """ self._interpreter.SetTensor(tensor_index, value) - def resize_tensor_input(self, input_index, tensor_size): + def resize_tensor_input(self, input_index, tensor_size, strict=False): """Resizes an input tensor. + ``` + interpreter = Interpreter(model_content=tflite_model) + interpreter.resize_tensor_input(0, [1, 224, 224, 3], strict=True) + interpreter.allocate_tensors() + interpreter.invoke() + ``` + Args: input_index: Tensor index of input to set. This value can be gotten from the 'index' field in get_input_details. tensor_size: The tensor_shape to resize the input to. + strict: Only unknown dimensions can be resized when `strict` is True. + Unknown dimensions are indicated as `-1` in the `shape_signature` + attribute of a given tensor. (default False) Raises: ValueError: If the interpreter could not resize the input tensor. @@ -421,7 +431,7 @@ class Interpreter(object): # `ResizeInputTensor` now only accepts int32 numpy array as `tensor_size # parameter. tensor_size = np.array(tensor_size, dtype=np.int32) - self._interpreter.ResizeInputTensor(input_index, tensor_size) + self._interpreter.ResizeInputTensor(input_index, tensor_size, strict) def get_output_details(self): """Gets model output details. diff --git a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc index bd78d56172e..844a9827cb6 100644 --- a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc +++ b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc @@ -257,7 +257,7 @@ PyObject* InterpreterWrapper::OutputIndices() const { return PyArray_Return(reinterpret_cast(np_array)); } -PyObject* InterpreterWrapper::ResizeInputTensor(int i, PyObject* value) { +PyObject* InterpreterWrapper::ResizeInputTensorImpl(int i, PyObject* value) { TFLITE_PY_ENSURE_VALID_INTERPRETER(); std::unique_ptr array_safe( @@ -282,10 +282,27 @@ PyObject* InterpreterWrapper::ResizeInputTensor(int i, PyObject* value) { return nullptr; } + PyArray_ENABLEFLAGS(reinterpret_cast(array), + NPY_ARRAY_OWNDATA); + return PyArray_Return(reinterpret_cast(array)); +} + +PyObject* InterpreterWrapper::ResizeInputTensor(int i, PyObject* value, + bool strict) { + PyArrayObject* array = + reinterpret_cast(ResizeInputTensorImpl(i, value)); + if (array == nullptr) { + return nullptr; + } + std::vector dims(PyArray_SHAPE(array)[0]); memcpy(dims.data(), PyArray_BYTES(array), dims.size() * sizeof(int)); - TFLITE_PY_CHECK(interpreter_->ResizeInputTensor(i, dims)); + if (strict) { + TFLITE_PY_CHECK(interpreter_->ResizeInputTensorStrict(i, dims)); + } else { + TFLITE_PY_CHECK(interpreter_->ResizeInputTensor(i, dims)); + } Py_RETURN_NONE; } diff --git a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h index b509c1ca199..2de38d07ed6 100644 --- a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h +++ b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h @@ -63,7 +63,7 @@ class InterpreterWrapper { PyObject* InputIndices() const; PyObject* OutputIndices() const; - PyObject* ResizeInputTensor(int i, PyObject* value); + PyObject* ResizeInputTensor(int i, PyObject* value, bool strict); int NumTensors() const; std::string TensorName(int i) const; @@ -110,6 +110,9 @@ class InterpreterWrapper { InterpreterWrapper(); InterpreterWrapper(const InterpreterWrapper& rhs); + // Helper function to resize an input tensor. + PyObject* ResizeInputTensorImpl(int i, PyObject* value); + // The public functions which creates `InterpreterWrapper` should ensure all // these member variables are initialized successfully. Otherwise it should // report the error and return `nullptr`. diff --git a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper_pybind11.cc b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper_pybind11.cc index b7ab65fdc86..1a61c2aa33b 100644 --- a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper_pybind11.cc +++ b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper_pybind11.cc @@ -70,9 +70,9 @@ PYBIND11_MODULE(_pywrap_tensorflow_interpreter_wrapper, m) { return tensorflow::PyoOrThrow(self.OutputIndices()); }) .def("ResizeInputTensor", - [](InterpreterWrapper& self, int i, py::handle& value) { + [](InterpreterWrapper& self, int i, py::handle& value, bool strict) { return tensorflow::PyoOrThrow( - self.ResizeInputTensor(i, value.ptr())); + self.ResizeInputTensor(i, value.ptr(), strict)); }) .def("NumTensors", &InterpreterWrapper::NumTensors) .def("TensorName", &InterpreterWrapper::TensorName) diff --git a/tensorflow/lite/python/lite_test.py b/tensorflow/lite/python/lite_test.py index e327aee1376..3ff9fbc3710 100644 --- a/tensorflow/lite/python/lite_test.py +++ b/tensorflow/lite/python/lite_test.py @@ -450,8 +450,15 @@ class FromSessionTest(TestModels, parameterized.TestCase): 3] == input_details[0]['shape_signature']).all()) self.assertEqual((0., 0.), input_details[0]['quantization']) + # Resize tensor with strict checking. + with self.assertRaises(RuntimeError) as error: + interpreter.resize_tensor_input(0, [3, 16, 16, 3], strict=True) + self.assertIn( + 'ResizeInputTensorStrict only allows mutating unknown dimensions ' + 'identified by -1.', str(error.exception)) + # Resize tensor and invoke. - interpreter.resize_tensor_input(0, [1, 16, 16, 3]) + interpreter.resize_tensor_input(0, [1, 16, 16, 3], strict=True) interpreter.allocate_tensors() interpreter.invoke() @@ -465,6 +472,34 @@ class FromSessionTest(TestModels, parameterized.TestCase): self.assertTrue(([1, -1, 16, 3] == output_details[0]['shape_signature']).all()) + def testResizeTensorInputStrict(self): + # Ensures that resize_tensor_input(strict=True) works as expected. + with ops.Graph().as_default(): + in_tensor = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32) + out_tensor = in_tensor + in_tensor + sess = session.Session() + + # Convert model and ensure model is not None. + converter = lite.TFLiteConverter.from_session(sess, [in_tensor], + [out_tensor]) + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Check values from converted model. + interpreter = Interpreter(model_content=tflite_model) + + # Resize incorrect value. + with self.assertRaises(RuntimeError) as error: + interpreter.resize_tensor_input(0, [3, 16, 16, 3], strict=True) + self.assertIn( + 'ResizeInputTensorStrict only allows mutating unknown dimensions ' + 'identified by -1.', str(error.exception)) + + # Resize correct value. + interpreter.resize_tensor_input(0, [1, 16, 16, 3], strict=True) + interpreter.allocate_tensors() + def testBatchSizeValid(self): with ops.Graph().as_default(): in_tensor = array_ops.placeholder( diff --git a/tensorflow/lite/python/lite_v2_test_util.py b/tensorflow/lite/python/lite_v2_test_util.py index 5ea239f22a2..b3da6a89b86 100644 --- a/tensorflow/lite/python/lite_v2_test_util.py +++ b/tensorflow/lite/python/lite_v2_test_util.py @@ -53,7 +53,7 @@ class ModelTest(test_util.TensorFlowTestCase, parameterized.TestCase): for idx, (shape_signature, final_shape) in enumerate(input_shapes): self.assertTrue( (input_details[idx]['shape_signature'] == shape_signature).all()) - interpreter.resize_tensor_input(idx, final_shape) + interpreter.resize_tensor_input(idx, final_shape, strict=True) interpreter.allocate_tensors() output_details = interpreter.get_output_details() diff --git a/tensorflow/tools/api/golden/v1/tensorflow.lite.-interpreter.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.lite.-interpreter.pbtxt index 5af7412e646..e1c235b5150 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.lite.-interpreter.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.lite.-interpreter.pbtxt @@ -36,7 +36,7 @@ tf_class { } member_method { name: "resize_tensor_input" - argspec: "args=[\'self\', \'input_index\', \'tensor_size\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_index\', \'tensor_size\', \'strict\'], varargs=None, keywords=None, defaults=[\'False\'], " } member_method { name: "set_tensor" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.lite.-interpreter.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.lite.-interpreter.pbtxt index 5af7412e646..e1c235b5150 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.lite.-interpreter.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.lite.-interpreter.pbtxt @@ -36,7 +36,7 @@ tf_class { } member_method { name: "resize_tensor_input" - argspec: "args=[\'self\', \'input_index\', \'tensor_size\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_index\', \'tensor_size\', \'strict\'], varargs=None, keywords=None, defaults=[\'False\'], " } member_method { name: "set_tensor"